未验证 提交 36c6c589 编写于 作者: HappyHeavyRain's avatar HappyHeavyRain 提交者: GitHub

Support the 'drop_empty_grad' in of output of backward_ops (#49588)

* support the drop_empty_grad in backward

* change code according to yunfei's review suggestion
上级 31ea3231
......@@ -459,6 +459,27 @@ def process_invoke_op(forward_op_dict, backward_op_dict):
)
def parse_drop_empty_grad(op_fluid_list: list, bw_op_dict: dict):
for op_op in op_fluid_list:
if 'drop_empty_grad' in op_op:
bw_names = [
bw_name.split('(')[0].strip()
for bw_name in op_op['backward'].split(',')
]
for bw_name in bw_names:
assert (
bw_name in bw_op_dict
), f"backward {bw_name} is not existed"
for out_grad in op_op['drop_empty_grad']:
assert (
out_grad in bw_op_dict[bw_name]['output_dict']
), f'''
{bw_name} with {out_grad} is not existed in output_dict '''
bw_op_dict[bw_name]['output_dict'][out_grad][
'drop_empty_grad'
] = False
def main(
ops_yaml_path,
backward_yaml_path,
......@@ -488,6 +509,11 @@ def main(
op['op_name'] = op['name']
for bw_op in backward_ops:
bw_op['op_name'] = bw_op['name']
for bw_output in bw_op['outputs']:
bw_output['drop_empty_grad'] = True
# deal the drop_empty_grad of bw_op by op_compat.yaml
parse_drop_empty_grad(op_fluid_map_list, backward_op_dict)
parse_composite_info(ops, backward_ops, backward_op_dict)
......
......@@ -417,7 +417,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
forward_input_names,
forward_output_names,
forward_input_orig_names,
forward_output_orig_names)}});
forward_output_orig_names,
output['drop_empty_grad'])}});
{% endfor %}
grad_op->SetAttrMap(this->Attrs());
......@@ -480,7 +481,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
forward_input_names,
forward_output_names,
forward_input_orig_names,
forward_output_orig_names)}});
forward_output_orig_names,
true)}});
{% endfor %}
{% for attr in invoke_op["attrs"] %}
......@@ -659,11 +661,15 @@ OutputGrad({{name_in_forward_orig | to_opmaker_name}})
{%- endmacro %}
{% macro extract_output_from_forward(name, input_names, output_names,
input_orig_names, output_orig_names) %}{# inline #}
input_orig_names, output_orig_names, drop_empty_grad) %}{# inline #}
{% if name[:-5] in input_names %}
{% set name_in_forward = name[:-5] %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
{%- if drop_empty_grad is true -%}
InputGrad({{name_in_forward_orig | to_opmaker_name}})
{%- else -%}
InputGrad({{name_in_forward_orig | to_opmaker_name}}, false)
{%- endif %}
{%- elif (name) in input_names %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
Input({{name | to_opmaker_name}})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class MeshgridOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto inputs = ctx.MultiInput<phi::DenseTensor>("X");
auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto* input : inputs) {
if (input->IsInitialized() && input->numel() > 0) {
input_data_type = framework::TransToProtoVarType(input->dtype());
flag = 1;
break;
}
}
if (flag == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"All Inputs of Meshgrid OP are Empty!"));
}
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
};
class MeshgridOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor, default Tensor<float>).").AsDuplicable();
AddOutput("Out", "(Tensor, default Tensor<float>.)").AsDuplicable();
AddComment(R"DOC(
Meshgrid Operator.
Take: N tensors, each of which can be either scalr or 1-dimensional vector, and create
N-dimensional grids.
Args:
tensors (list of tensor): if the input k tensors has (N1,), (N2,),..., (Nk,), then
the output tensors are all of size (N1, N2, ...., Nk).
Example::
>>> x = fluid.data(name='x', shape=[10], dtype='float64')
>>> y = fluid.data(name='y', shape=[20], dtype='float64')
>>> grid_x, grid_y = fluid.layers.meshgrid([x, y])
>>> grid_x.shape
(10,20)
>>> grid_y.shape
(10,20)
)DOC");
}
};
class MeshgridGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Out")).size(),
1,
platform::errors::InvalidArgument(
"Number of Inputs(Out@Grad) should be larger than 1."
"But received Inputs(Out@Grad)' size = %d .",
ctx->Inputs(framework::GradVarName("Out")).size()));
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context().GetPlace());
}
};
template <typename T>
class MeshgridGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("meshgrid_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(meshgrid,
MeshgridInferShapeFunctor,
PD_INFER_META(phi::MeshgridInferMeta));
REGISTER_OPERATOR(meshgrid,
ops::MeshgridOp,
ops::MeshgridOpMaker,
ops::MeshgridGradOpMaker<paddle::framework::OpDesc>,
ops::MeshgridGradOpMaker<paddle::imperative::OpBase>,
MeshgridInferShapeFunctor);
REGISTER_OPERATOR(meshgrid_grad, ops::MeshgridGradOp);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace operators {
class MultiDotOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensors of multi_dot operator.").AsDuplicable();
AddOutput("Out", "The output tensor of multi_dot operator");
AddComment(R"DOC(
Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.
multi_dot chains MatMul and uses optimal parenthesization of the matrices [1] [2]. Depending on the shapes of the matrices, this can speed up the multiplication a lot.
If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.
)DOC");
}
};
class MultiDotOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class MultiDotOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "multi_dot");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"multi_dot");
auto in_x = "X";
auto out_x_g_n = framework::GradVarName(in_x);
auto ins_dims = ctx->GetInputsDim(in_x);
ctx->SetOutputsDim(out_x_g_n, ins_dims);
ctx->ShareAllLoD(in_x, out_x_g_n);
}
};
template <typename T>
class MultiDotOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("multi_dot_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
}
};
template <typename T>
class MultiDotOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("multi_dot");
grad_op->SetInput("X", this->Input(("X")));
grad_op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
grad_op->SetOutput("DDx", this->OutputGrad(framework::GradVarName("X")));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(multi_dot,
MultiDotInferShapeFunctor,
PD_INFER_META(phi::MultiDotInferMeta));
REGISTER_OPERATOR(multi_dot,
ops::MultiDotOp,
ops::MultiDotOpMaker,
ops::MultiDotOpGradMaker<paddle::framework::OpDesc>,
ops::MultiDotOpGradMaker<paddle::imperative::OpBase>,
MultiDotInferShapeFunctor);
REGISTER_OPERATOR(multi_dot_grad,
ops::MultiDotOpGrad,
ops::MultiDotOpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MultiDotOpDoubleGradMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class MultiplexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context().GetPlace());
}
};
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"Tensor<int32>, index variable which is a 2-D tensor with shape "
"[M, 1] where M is the batch size.");
AddInput("X",
"A list of variables to gather from. All variables have the same "
"shape and the rank is at least 2.")
.AsDuplicable();
AddOutput("Out", "The output tensor of multiplex operator.");
AddComment(R"DOC(
Referring to the given index variable, this layer selects rows from the
input variables to construct a multiplex variable. Assuming that there are
:math:`m` input variables and :math:`I_i` represents the i-th input
variable and :math:`i` is in [0, :math:`m`). All input variables are
tensors with same shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`].
Please note that rank of the input tensor should be at least 2. Each input
variable will be treated as a 2-D matrix with shape [:math:`M`, :math:`N`]
where :math:`M` for :math:`d_0` and :math:`N` for :math:`d_1` * :math:`d_2`
* ... * :math:`d_R`. Let :math:`I_i[j]` be the j-th row of the i-th input
variable. The given index variable should be a 2-D tensor with shape
[:math:`M`, 1]. Let `ID[i]` be the i-th index value of the index variable.
Then the output variable will be a tensor with shape [:math:`d_0`,
:math:`d_1`, ..., :math:`d_R`]. If we treat the output tensor as a 2-D
matrix with shape [:math:`M`, :math:`N`] and let :math:`O[i]` be the i-th
row of the matrix, then `O[i]` is equal to :math:`I_{ID[i]}[i]`.
* Ids: the index tensor.
* X[0 : N - 1]: the candidate tensors for output (N >= 2).
* For each index i from 0 to batchSize - 1, the output is the i-th row of the
the (Ids[i])-th tensor.
For i-th row of the output tensor:
$$
y[i] = x_{k}[i]
$$
where $y$ is the output tensor, $x_{k}$ is the k-th input tensor,
and $k = Ids[i]$.
)DOC");
}
};
class MultiplexGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto dxs = ctx->Outputs(framework::GradVarName("X"));
PADDLE_ENFORCE_NE(dxs.empty(),
true,
platform::errors::InvalidArgument(
"Output(X@Grad) should not be null."));
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"MultiplexGrad");
auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputsDim(framework::GradVarName("X"),
std::vector<framework::DDim>(dxs.size(), dout_dim));
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context().GetPlace());
}
};
template <typename T>
class MultiplexGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("multiplex_grad");
op->SetInput("Ids", this->Input("Ids"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(multiplex,
MultiplexInferShapeFunctor,
PD_INFER_META(phi::MultiplexInferMeta));
REGISTER_OPERATOR(multiplex,
ops::MultiplexOp,
ops::MultiplexOpMaker,
ops::MultiplexGradMaker<paddle::framework::OpDesc>,
ops::MultiplexGradMaker<paddle::imperative::OpBase>,
MultiplexInferShapeFunctor);
REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
......@@ -841,6 +841,16 @@
kernel :
func : maxout_grad
- backward_op : meshgrid_grad
forward : meshgrid (Tensor[] inputs) -> Tensor[](outputs)
args : (Tensor[] inputs, Tensor[] outputs_grad)
output : Tensor[](inputs_grad){inputs.size()}
infer_meta :
func : MeshgridGradInferMeta
kernel :
func : meshgrid_grad
data_type : outputs_grad
- backward_op : mode_grad
forward : mode(Tensor x, int axis = -1, bool keepdim = false) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, bool keepdim)
......@@ -851,6 +861,27 @@
kernel :
func : mode_grad
- backward_op : multi_dot_grad
forward : multi_dot (Tensor[] x) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad)
output : Tensor[](x_grad) {x.size()}
infer_meta :
func : MultiDotGradInferMeta
kernel :
func : multi_dot_grad
- backward_op : multiplex_grad
forward : multiplex (Tensor[] inputs, Tensor index) -> Tensor(out)
args : (Tensor[] inputs, Tensor index, Tensor out_grad)
output : Tensor[](inputs_grad){inputs.size()}
infer_meta :
func : MultiplexGradInferMeta
param : [index, out_grad]
kernel :
func : multiplex_grad
param : [index, out_grad]
data_type : out_grad
- backward_op : mv_grad
forward : mv (Tensor x, Tensor vec) -> Tensor(out)
args : (Tensor x, Tensor vec, Tensor out_grad)
......
......@@ -806,15 +806,6 @@
backward : mean_double_grad
no_need_buffer : x
- backward_op : meshgrid_grad
forward : meshgrid (Tensor[] inputs) -> Tensor[](outputs)
args : (Tensor[] inputs, Tensor[] outputs_grad)
output : Tensor[](inputs_grad){inputs.size()}
infer_meta :
func : MeshgridGradInferMeta
kernel :
func : meshgrid_grad
- backward_op : min_grad
forward: min (Tensor x, IntArray axis={}, bool keepdim=false) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis={}, bool keepdim=false, bool reduce_all=false)
......@@ -846,26 +837,6 @@
func : mish_grad
inplace : (out_grad -> x_grad)
- backward_op : multi_dot_grad
forward : multi_dot (Tensor[] x) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad)
output : Tensor[](x_grad) {x.size()}
infer_meta :
func : MultiDotGradInferMeta
kernel :
func : multi_dot_grad
- backward_op : multiplex_grad
forward : multiplex (Tensor[] inputs, Tensor index) -> Tensor(out)
args : (Tensor[] inputs, Tensor index, Tensor out_grad)
output : Tensor[](inputs_grad){inputs.size()}
infer_meta :
func : MultiplexGradInferMeta
param : [index, out_grad]
kernel :
func : multiplex_grad
param : [index, out_grad]
- backward_op : multiply_double_grad
forward : multiply_grad (Tensor x, Tensor y, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
args : (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
......
......@@ -1177,15 +1177,6 @@
data_type : param
inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out)
- op : meshgrid
args : (Tensor[] inputs)
output : Tensor[]{inputs.size()}
infer_meta :
func : MeshgridInferMeta
kernel :
func : meshgrid
backward : meshgrid_grad
- op : min
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
......@@ -1225,15 +1216,6 @@
optional : master_param
inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out)
- op : multi_dot
args : (Tensor[] x)
output : Tensor
infer_meta :
func : MultiDotInferMeta
kernel :
func : multi_dot
backward : multi_dot_grad
- op : multiclass_nms3
args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0)
output : Tensor(out), Tensor(index), Tensor(nms_rois_num)
......@@ -1243,16 +1225,6 @@
func : multiclass_nms3
optional : rois_num
- op : multiplex
args : (Tensor[] inputs, Tensor index)
output : Tensor
infer_meta :
func : MultiplexInferMeta
kernel :
func : multiplex
data_type : inputs
backward : multiplex_grad
- op : multiply
args : (Tensor x, Tensor y)
output : Tensor
......
......@@ -919,6 +919,14 @@
outputs :
out : Out
- op : meshgrid
backward : meshgrid_grad
inputs :
inputs : X
outputs :
out : Out
drop_empty_grad : [inputs_grad]
- op : mish
backward : mish_grad
extra :
......@@ -931,6 +939,14 @@
outputs :
{out : Out, indices : Indices}
- op : multi_dot
backward : multi_dot_grad
inputs :
x : X
outputs :
out : Out
drop_empty_grad : [x_grad]
- op : multinomial
inputs :
{x : X}
......@@ -941,6 +957,14 @@
data_type : int
support_tensor : true
- op : multiplex
backward : multiplex_grad
inputs :
{inputs : X, index : Ids}
outputs :
out : Out
drop_empty_grad : [inputs_grad]
- op : multiply (elementwise_mul)
backward : multiply_grad (elementwise_mul_grad)
extra :
......
......@@ -795,6 +795,16 @@
func : maxout
backward : maxout_grad
- op : meshgrid
args : (Tensor[] inputs)
output : Tensor[]{inputs.size()}
infer_meta :
func : MeshgridInferMeta
kernel :
func : meshgrid
data_type : inputs
backward : meshgrid_grad
- op : mode
args : (Tensor x, int axis = -1, bool keepdim = false)
output : Tensor(out), Tensor(indices)
......@@ -804,6 +814,15 @@
func : mode
backward : mode_grad
- op : multi_dot
args : (Tensor[] x)
output : Tensor
infer_meta :
func : MultiDotInferMeta
kernel :
func : multi_dot
backward : multi_dot_grad
- op : multinomial
args : (Tensor x, Scalar(int) num_samples = 1, bool replacement = false)
output : Tensor(out)
......@@ -813,6 +832,16 @@
func : multinomial
data_type : x
- op : multiplex
args : (Tensor[] inputs, Tensor index)
output : Tensor
infer_meta :
func : MultiplexInferMeta
kernel :
func : multiplex
data_type : inputs
backward : multiplex_grad
- op : mv
args : (Tensor x, Tensor vec)
output : Tensor
......
......@@ -658,8 +658,10 @@ void MultiplexGradInferMeta(const MetaTensor& ids,
errors::InvalidArgument("Output(X@Grad) should not be null."));
auto dout_dim = out_grad.dims();
for (auto in_grad : ins_grad) {
if (in_grad != nullptr) {
in_grad->set_dims(dout_dim);
}
}
}
void NanmedianGradInferMeta(const MetaTensor& x,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MeshgridOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("meshgrid", {"X"}, {}, {"Out"});
}
KernelSignature MeshgridGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("meshgrid_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(meshgrid, phi::MeshgridOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(meshgrid_grad, phi::MeshgridGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MultiDotGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("multi_dot_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(multi_dot_grad, phi::MultiDotGradOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MultiplexOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("multiplex", {"X", "Ids"}, {}, {"Out"});
}
KernelSignature MultiplexGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("multiplex_grad", {"Ids", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(multiplex, phi::MultiplexOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(multiplex_grad, phi::MultiplexGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册