Unverified commit 4c0d46a8, authored by HappyHeavyRain, committed by GitHub

Generate static graph code of some ops by yaml (#48771)

* generate static graph code of some ops by yaml, test = develop

* fix 'take_along_axis' yaml style

* reset scatter/scatter_nd_add

* delete the comments of put_along_axis
Parent 9455d146
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class LogLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
template <typename AttrType>
class LogLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Predicted",
"The input value (Predicted) of Log loss op."
"Predicted is a 2-D tensor with shape [batch_size, 1].");
AddInput("Labels",
"The target value (Labels) of Log loss op."
"Labels is a 2-D tensor with shape [batch_size, 1].");
AddOutput("Loss",
"The output tensor with shape [batch_size, 1] "
"which represents the log loss.");
AddAttr<AttrType>("epsilon", "Epsilon in log loss.");
AddComment(R"DOC(
LogLoss Operator.
Log loss is a loss function used for binary classification. Log Loss quantifies
the accuracy of a classifier by penalising false classifications. Minimising the
Log Loss is equivalent to maximising the accuracy of the classifier. We define
Predicted as the values predicted by our model and Labels as the target ground
truth value. Log loss can evaluate how close the predicted values are to the
target. The shapes of Predicted and Labels are both [batch_size, 1].
The equation is:
$$
Loss = - Labels * log(Predicted + \epsilon) -
(1 - Labels) * log(1 - Predicted + \epsilon)
$$
)DOC");
}
};
class LogLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("Predicted"), "Input", "Predicted", "LogLossGrad");
OP_INOUT_CHECK(ctx->HasInput("Labels"), "Input", "Labels", "LogLossGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Loss")),
"Input",
framework::GradVarName("Loss"),
"LogLossGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Predicted")),
"Output",
framework::GradVarName("Predicted"),
"LogLossGrad");
auto pred_dims = ctx->GetInputDim("Predicted");
auto loss_grad_dims = ctx->GetInputDim(framework::GradVarName("Loss"));
PADDLE_ENFORCE_EQ(loss_grad_dims,
pred_dims,
platform::errors::InvalidArgument(
"The dimensions of loss_grad must be equal to the "
"dimensions of Predicted,"
"But received dimensions of loss_grad is [%s], "
"received Predicted is "
"[%s]",
loss_grad_dims,
pred_dims));
auto pred_grad_name = framework::GradVarName("Predicted");
ctx->SetOutputDim(pred_grad_name, pred_dims);
}
};
template <typename T>
class LogLossGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("log_loss_grad");
op->SetInput("Predicted", this->Input("Predicted"));
op->SetInput("Labels", this->Input("Labels"));
op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
op->SetOutput(framework::GradVarName("Predicted"),
this->InputGrad("Predicted"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(log_loss,
LogLossInferShapeFunctor,
PD_INFER_META(phi::LogLossInferMeta));
REGISTER_OPERATOR(log_loss,
ops::LogLossOp,
ops::LogLossOpMaker<float>,
ops::LogLossGradMaker<paddle::framework::OpDesc>,
ops::LogLossGradMaker<paddle::imperative::OpBase>,
LogLossInferShapeFunctor);
REGISTER_OPERATOR(log_loss_grad, ops::LogLossGradOp);
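For reference — a minimal, self-contained sketch (plain C++, scalar case; not part of this change) of the forward and backward math that the log_loss kernels implement, following the DOC formula above:

#include <cmath>
#include <cstdio>

// loss = -label * log(pred + eps) - (1 - label) * log(1 - pred + eps)
double log_loss(double pred, double label, double eps) {
  return -label * std::log(pred + eps) -
         (1.0 - label) * std::log(1.0 - pred + eps);
}

// d(loss)/d(pred), scaled by the incoming loss gradient; this is what
// log_loss_grad produces for Predicted@GRAD.
double log_loss_grad(double pred, double label, double eps, double dloss) {
  return dloss * (-label / (pred + eps) + (1.0 - label) / (1.0 - pred + eps));
}

int main() {
  // A confident prediction for a positive label yields a small loss.
  std::printf("loss=%f dpred=%f\n",
              log_loss(0.9, 1.0, 1e-4),
              log_loss_grad(0.9, 1.0, 1e-4, 1.0));
  return 0;
}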
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle {
namespace operators {
class PutAlongAxisOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
class PutAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The input tensor of PutAlongAxisOp");
AddInput("Index", "The index tensor of PutAlongAxisOp");
AddInput("Value", "The value tensor of PutAlongAxisOp");
AddOutput("Result", "The result tensor of PutAlongAxisOp");
AddAttr<int>("Axis", "The axis that we do PutAlongAxis operation");
AddAttr<std::string>("Reduce", "The reduce operation for scatter")
.SetDefault("assign");
AddComment(R"DOC(
PutAlongAxis Operator.
)DOC");
}
};
class PutAlongAxisGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Result")),
ctx.device_context());
}
};
template <typename T>
class PutAlongAxisGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("put_along_axis_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("Input", this->Input("Input"));
op->SetInput(framework::GradVarName("Result"), this->OutputGrad("Result"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("Value"), this->InputGrad("Value"));
op->SetAttrMap(this->Attrs());
}
};
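// Declares that "Result" may reuse the memory of "Input" when the op runs
// in-place; this pairs with the `inplace : (arr -> out)` entry in the yaml
// definition below.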
DECLARE_INPLACE_OP_INFERER(PutAlongAxisInplaceInferer, {"Input", "Result"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(put_along_axis,
PutAlongAxisInferShapeFunctor,
PD_INFER_META(phi::PutAlongAxisInferMeta));
REGISTER_OPERATOR(put_along_axis,
ops::PutAlongAxisOp,
ops::PutAlongAxisOpMaker,
ops::PutAlongAxisGradOpMaker<paddle::framework::OpDesc>,
ops::PutAlongAxisGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::PutAlongAxisInplaceInferer,
PutAlongAxisInferShapeFunctor);
REGISTER_OPERATOR(put_along_axis_grad, ops::PutAlongAxisGradOp);
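As a reading aid — a 1-D sketch (plain C++; illustration only, not this kernel) of the scatter semantics put_along_axis implements; the real op applies this along a chosen axis of an N-D tensor, with reduce choosing between overwrite and accumulation:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Result starts as a copy of Input; Value[i] is then written to
// Result[Index[i]], replacing ("assign", the yaml default) or
// accumulating ("add") the existing element.
std::vector<float> put_along_axis_1d(const std::vector<float>& input,
                                     const std::vector<int>& index,
                                     const std::vector<float>& value,
                                     const std::string& reduce) {
  std::vector<float> result = input;
  for (std::size_t i = 0; i < index.size(); ++i) {
    assert(index[i] >= 0 &&
           static_cast<std::size_t>(index[i]) < result.size());
    if (reduce == "add") {
      result[index[i]] += value[i];
    } else {
      result[index[i]] = value[i];
    }
  }
  return result;
}

int main() {
  auto r = put_along_axis_1d({0, 0, 0, 0}, {1, 3}, {5, 7}, "assign");
  assert(r[1] == 5 && r[3] == 7);
  return 0;
}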
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class SearchSortedOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "SortedSequence");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("SortedSequence",
"(Tensor), N-D or 1-D tensor, The value of the tensor"
"monotonically increases in the innermost dimension.");
AddInput("Values", "(Tensor), N-D tensor given values.");
AddOutput("Out", "(Tensor), The output tensor of searchsorted op.");
AddAttr<bool>("out_int32",
"the output tensor is int64 type if False and on the"
"contrary for int32")
.SetDefault(false);
AddAttr<bool>(
"right",
"corresponding to lower bound if False and upper bound if True")
.SetDefault(false);
AddComment(R"DOC(
Searchsorted Operator.
This OP finds, in the innermost dimension of SortedSequence, the indices at which the given Values would be inserted to keep the sequence sorted.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(searchsorted,
SearchsortedInferShapeFunctor,
PD_INFER_META(phi::SearchsortedInferMeta));
REGISTER_OPERATOR(searchsorted,
ops::SearchSortedOp,
ops::SearchSortedOpMaker,
SearchsortedInferShapeFunctor);
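A 1-D sketch (plain C++; the operator also broadcasts over the outer dimensions of SortedSequence) of the bound semantics described in the attribute docs: right = false returns the lower-bound insertion index, right = true the upper bound:

#include <algorithm>
#include <cstdio>
#include <vector>

// Index at which `v` would be inserted into `sorted` to keep it sorted.
int searchsorted_1d(const std::vector<float>& sorted, float v, bool right) {
  auto it = right ? std::upper_bound(sorted.begin(), sorted.end(), v)
                  : std::lower_bound(sorted.begin(), sorted.end(), v);
  return static_cast<int>(it - sorted.begin());
}

int main() {
  const std::vector<float> seq = {1, 3, 5, 7};
  std::printf("%d %d\n",
              searchsorted_1d(seq, 5, /*right=*/false),  // prints 2
              searchsorted_1d(seq, 5, /*right=*/true));  // prints 3
  return 0;
}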
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class SvdOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class SvdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of svd op.");
AddOutput("U", "(Tensor), The output U tensor of svd op.");
AddOutput("S", "(Tensor), The output S tensor of svd op.");
AddOutput("VH", "(Tensor), The output VH tensor of svd op.");
AddAttr<bool>("full_matrices",
"(bool, default false) Only Compute the thin U and V"
"when set as True, the gradient have some random "
"attribute.")
.SetDefault(false);
AddComment(R"DOC(
Svd Operator.
This operator performs the SVD operation on batched matrices $X$.
$$U, S, VH = svd(X)$$
)DOC");
}
};
class SvdGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("U")),
"Input",
"U@Grad",
"SvdGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("VH")),
"Input",
"VH@Grad",
"SvdGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("S")),
"Input",
"S@Grad",
"SvdGrad");
OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SvdGrad");
OP_INOUT_CHECK(ctx->HasInput("S"), "Input", "S", "SvdGrad");
OP_INOUT_CHECK(ctx->HasInput("VH"), "Input", "VH", "SvdGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
"Output",
"X@Grad",
"SvdGrad");
auto d_x = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), d_x);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
template <typename T>
class SvdGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("svd_grad");
retv->SetInput(framework::GradVarName("U"), this->OutputGrad("U"));
retv->SetInput(framework::GradVarName("VH"), this->OutputGrad("VH"));
retv->SetInput(framework::GradVarName("S"), this->OutputGrad("S"));
retv->SetInput("U", this->Output("U"));
retv->SetInput("VH", this->Output("VH"));
retv->SetInput("S", this->Output("S"));
retv->SetInput("X", this->Input("X"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(svd,
SvdInferShapeFunctor,
PD_INFER_META(phi::SvdInferMeta));
REGISTER_OPERATOR(svd,
ops::SvdOp,
ops::SvdOpMaker,
ops::SvdGradMaker<paddle::framework::OpDesc>,
ops::SvdGradMaker<paddle::imperative::OpBase>,
SvdInferShapeFunctor);
REGISTER_OPERATOR(svd_grad, ops::SvdGradOp);
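A shape-only sketch (plain C++, standard SVD conventions; illustrative, not this kernel) of what full_matrices changes for an input with trailing dimensions [m, n]:

#include <algorithm>
#include <cstdio>

// For X with shape [m, n] and k = min(m, n):
//   full_matrices = false (thin): U is [m, k], S is [k], VH is [k, n]
//   full_matrices = true  (full): U is [m, m], S is [k], VH is [n, n]
void svd_shapes(int m, int n, bool full_matrices) {
  const int k = std::min(m, n);
  const int u_cols = full_matrices ? m : k;
  const int vh_rows = full_matrices ? n : k;
  std::printf("U:[%d,%d] S:[%d] VH:[%d,%d]\n", m, u_cols, k, vh_rows, n);
}

int main() {
  svd_shapes(4, 3, false);  // U:[4,3] S:[3] VH:[3,3]
  svd_shapes(4, 3, true);   // U:[4,4] S:[3] VH:[3,3]
  return 0;
}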
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
class TakeAlongAxisOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
class TakeAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The input tensor of TakeAlongAxisOp");
AddInput("Index", "The index tensor of TakeAlongAxisOp");
AddOutput("Result", "The result tensor of TakeAlongAxisOp");
AddAttr<int>("Axis",
"The Tensor which contains the axis that we do TakeAlongAxis "
"operation.");
AddComment(R"DOC(
Take_along_axis Operator.
)DOC");
}
};
class TakeAlongAxisGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Result")),
ctx.device_context());
}
};
template <typename T>
class TakeAlongAxisGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("take_along_axis_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("Input", this->Input("Input"));
op->SetInput(framework::GradVarName("Result"), this->OutputGrad("Result"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(take_along_axis,
TakeAlongAxisInferShapeFunctor,
PD_INFER_META(phi::TakeAlongAxisInferMeta));
REGISTER_OPERATOR(take_along_axis,
ops::TakeAlongAxisOp,
ops::TakeAlongAxisOpMaker,
ops::TakeAlongAxisGradOpMaker<paddle::framework::OpDesc>,
ops::TakeAlongAxisGradOpMaker<paddle::imperative::OpBase>,
TakeAlongAxisInferShapeFunctor);
REGISTER_OPERATOR(take_along_axis_grad, ops::TakeAlongAxisGradOp);
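And the matching 1-D gather sketch (plain C++, illustration only) for take_along_axis; its gradient scatters out_grad back through the same indices, which is why the backward mirrors put_along_axis:

#include <cassert>
#include <cstddef>
#include <vector>

// Result[i] = Input[Index[i]] -- the gather performed along the given axis.
std::vector<float> take_along_axis_1d(const std::vector<float>& input,
                                      const std::vector<int>& index) {
  std::vector<float> result(index.size());
  for (std::size_t i = 0; i < index.size(); ++i) {
    assert(index[i] >= 0 &&
           static_cast<std::size_t>(index[i]) < input.size());
    result[i] = input[index[i]];
  }
  return result;
}

int main() {
  auto r = take_along_axis_1d({10, 20, 30}, {2, 0, 2});
  assert(r[0] == 30 && r[1] == 10 && r[2] == 30);
  return 0;
}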
@@ -676,6 +676,16 @@
backward : log_double_grad
inplace : (out_grad -> x_grad)
- backward_op : log_loss_grad
forward : log_loss (Tensor input, Tensor label, float epsilon) -> Tensor(out)
args : (Tensor input, Tensor label, Tensor out_grad, float epsilon)
output : Tensor(input_grad)
infer_meta :
func : UnchangedInferMeta
param : [input]
kernel :
func : log_loss_grad
- backward_op : logit_grad
forward : logit (Tensor x, float eps = 1e-6f) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float eps)
@@ -779,6 +789,16 @@
kernel :
func : poisson_grad
- backward_op : put_along_axis_grad
forward : put_along_axis (Tensor arr, Tensor indices, Tensor value, int axis, str reduce = "assign") -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor out_grad, int axis, str reduce)
output : Tensor(arr_grad), Tensor(value_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [arr, indices]
kernel :
func : put_along_axis_grad
- backward_op : qr_grad
forward : qr (Tensor x, str mode = "reduced") -> Tensor(q), Tensor(r)
args : (Tensor x, Tensor q, Tensor r, Tensor q_grad, Tensor r_grad, str mode)
@@ -1062,6 +1082,27 @@
backward : square_double_grad
inplace : (out_grad -> x_grad)
- backward_op : svd_grad
forward : svd (Tensor x, bool full_matrices = false) -> Tensor(u), Tensor(s), Tensor(vh)
args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full_matrices)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : svd_grad
optional : u_grad, vh_grad, s_grad
- backward_op : take_along_axis_grad
forward : take_along_axis (Tensor arr, Tensor indices, int axis) -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor out_grad, int axis)
output : Tensor(arr_grad)
infer_meta :
func : UnchangedInferMeta
param : [arr]
kernel :
func : take_along_axis_grad
- backward_op : tan_grad
forward : tan (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
@@ -745,16 +745,6 @@
func : linear_interp_grad
data_type : output_grad
- backward_op : log_loss_grad
forward : log_loss (Tensor input, Tensor label, float epsilon) -> Tensor(out)
args : (Tensor input, Tensor label, Tensor out_grad, float epsilon)
output : Tensor(input_grad)
infer_meta :
func : UnchangedInferMeta
param : [input]
kernel :
func : log_loss_grad
- backward_op : log_softmax_grad
forward : log_softmax(Tensor x, int axis) -> Tensor(out)
args : (Tensor out, Tensor out_grad, int axis)
@@ -1195,17 +1185,6 @@
data_type : x
optional : boxes_num
# output is optional
- backward_op : put_along_axis_grad
forward : put_along_axis (Tensor arr, Tensor indices, Tensor value, int axis, str reduce) -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor out_grad, int axis, str reduce)
output : Tensor(arr_grad), Tensor(value_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [arr, indices]
kernel :
func : put_along_axis_grad
- backward_op : real_grad
forward : real (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
@@ -1573,17 +1552,6 @@
no_need_buffer : x
backward : sum_double_grad
- backward_op : svd_grad
forward : svd (Tensor x, bool full_matrices) -> Tensor(u), Tensor(s), Tensor(vh)
args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full_matrices)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : svd_grad
optional : u_grad, vh_grad, s_grad
- backward_op : swish_grad
forward : swish (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float beta=1.0)
@@ -1607,16 +1575,6 @@
data_type : out_grad
optional : reserve_space
- backward_op : take_along_axis_grad
forward : take_along_axis (Tensor arr, Tensor indices, int axis) -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor out_grad, int axis)
output : Tensor(arr_grad)
infer_meta :
func : UnchangedInferMeta
param : [arr]
kernel :
func : take_along_axis_grad
- backward_op : temporal_shift_grad
forward : temporal_shift(Tensor x, int seg_num, float shift_ratio, str data_format_str) -> Tensor(out)
args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format_str)
@@ -1068,15 +1068,6 @@
data_type : dtype
backend : place
- op : log_loss
args : (Tensor input, Tensor label, float epsilon)
output : Tensor
infer_meta :
func : LogLossInferMeta
kernel :
func : log_loss
backward : log_loss_grad
- op : log_softmax
args : (Tensor x, int axis)
output : Tensor(out)
@@ -1555,18 +1546,6 @@
optional : boxes_num
backward : psroi_pool_grad
- op : put_along_axis
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [arr]
kernel :
func : put_along_axis
data_type : arr
inplace : (arr -> out)
backward : put_along_axis_grad
- op : randint
args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={})
output : Tensor(out)
@@ -1750,15 +1729,6 @@
func : scatter_nd_add
backward : scatter_nd_add_grad
- op : searchsorted
args : (Tensor sorted_sequence, Tensor values, bool out_int32, bool right)
output : Tensor(out)
infer_meta :
func : SearchsortedInferMeta
kernel :
func : searchsorted
data_type : sorted_sequence
- op : segment_pool
args : (Tensor x, Tensor segment_ids, str pooltype)
output : Tensor(out), Tensor(summed_ids)
@@ -1968,15 +1938,6 @@
data_type : x
backward : sum_grad
- op : svd
args : (Tensor x, bool full_matrices)
output : Tensor(u), Tensor(s), Tensor(vh)
infer_meta :
func : SvdInferMeta
kernel :
func : svd
backward : svd_grad
- op : swish
args : (Tensor x)
output : Tensor(out)
@@ -1998,17 +1959,6 @@
backward : sync_batch_norm_grad
inplace : (mean -> mean_out), (variance -> variance_out)
- op : take_along_axis
args : (Tensor arr, Tensor indices, int axis)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [indices]
kernel :
func : take_along_axis
data_type : arr
backward : take_along_axis_grad
- op : temporal_shift
args : (Tensor x, int seg_num, float shift_ratio, str data_format_str)
output : Tensor
@@ -750,6 +750,13 @@
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : log_loss
backward : log_loss_grad
inputs :
{input : Predicted, label : Labels}
outputs :
out : Loss
- op : log_softmax
backward : log_softmax_grad
extra :
@@ -916,6 +923,15 @@
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
- op : put_along_axis
backward : put_along_axis_grad
inputs :
{arr : Input, indices : Index, values : Value}
outputs :
out : Result
attrs :
{axis : Axis, reduce : Reduce}
- op : qr
backward : qr_grad
inputs :
@@ -1029,6 +1045,12 @@
extra :
attrs : [bool use_mkldnn = false]
- op : searchsorted
inputs :
{sorted_sequence : SortedSequence, values : Values}
outputs :
out : Out
- op : seed
extra :
attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false]
@@ -1176,6 +1198,13 @@
attrs : [bool use_mkldnn = false, str x_data_format = "", str y_data_format = "", str mkldnn_data_type = "float32",
bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
- op : svd
backward : svd_grad
inputs :
x : X
outputs :
{u : U, s : S, vh : VH}
- op : swish
backward : swish_grad
extra :
@@ -1186,6 +1215,15 @@
extra :
attrs : [bool use_mkldnn = false, bool fuse_with_relu = false]
- op : take_along_axis
backward : take_along_axis_grad
inputs :
{arr : Input, indices : Index}
outputs :
out : Result
attrs :
axis : Axis
- op : tan
backward : tan_grad
inputs :
@@ -628,6 +628,15 @@
func : log2
backward: log2_grad
- op : log_loss
args : (Tensor input, Tensor label, float epsilon)
output : Tensor
infer_meta :
func : LogLossInferMeta
kernel :
func : log_loss
backward : log_loss_grad
- op : logit
args : (Tensor x, float eps = 1e-6f)
output : Tensor
@@ -741,6 +750,18 @@
func : poisson
backward : poisson_grad
- op : put_along_axis
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign")
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [arr]
kernel :
func : put_along_axis
data_type : arr
inplace : (arr -> out)
backward : put_along_axis_grad
- op : qr
args : (Tensor x, str mode = "reduced")
output : Tensor(q), Tensor(r)
@@ -800,6 +821,15 @@
inplace : (x -> out)
backward : rsqrt_grad
- op : searchsorted
args : (Tensor sorted_sequence, Tensor values, bool out_int32 = false, bool right = false)
output : Tensor(out)
infer_meta :
func : SearchsortedInferMeta
kernel :
func : searchsorted
data_type : sorted_sequence
- op : send_uv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
output : Tensor(out)
@@ -907,6 +937,26 @@
square_sr {selected_rows -> selected_rows}
backward : square_grad
- op : svd
args : (Tensor x, bool full_matrices = false)
output : Tensor(u), Tensor(s), Tensor(vh)
infer_meta :
func : SvdInferMeta
kernel :
func : svd
backward : svd_grad
- op : take_along_axis
args : (Tensor arr, Tensor indices, int axis)
output : Tensor
infer_meta :
func : TakeAlongAxisInferMeta
param : [arr, indices, axis]
kernel :
func : take_along_axis
data_type : arr
backward : take_along_axis_grad
- op : tan
args : (Tensor x)
output : Tensor
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
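// Maps the legacy fluid names (Predicted, Labels, Loss@GRAD) to the phi
// kernel signature log_loss_grad(input, label, out_grad, epsilon), matching
// the op_compat.yaml entry added above.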
KernelSignature LogLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("log_loss_grad",
{"Predicted", "Labels", "Loss@GRAD"},
{"epsilon"},
{"Predicted@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(log_loss_grad, phi::LogLossGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
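// Bridges the fluid op names (Input, Index, Value, Axis, Reduce, Result) to
// the phi kernels put_along_axis / put_along_axis_grad, mirroring the
// op_compat.yaml mapping added above.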
KernelSignature PutAlongAxisArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("put_along_axis",
{"Input", "Index", "Value"},
{"Axis", "Reduce"},
{"Result"});
}
KernelSignature PutAlongAxisGradArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("put_along_axis_grad",
{"Input", "Index", "Result@GRAD"},
{"Axis", "Reduce"},
{"Input@GRAD", "Value@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(put_along_axis, phi::PutAlongAxisArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(put_along_axis_grad,
phi::PutAlongAxisGradArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
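// Routes X, U, VH, S and their gradients to the phi kernel svd_grad; the
// u_grad/vh_grad/s_grad arguments are optional per the backward yaml entry.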
KernelSignature SvdGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("svd_grad",
{"X", "U", "VH", "S", "U@GRAD", "VH@GRAD", "S@GRAD"},
{"full_matrices"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(svd_grad, phi::SvdGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
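// Pairs the fluid names (Input, Index, Axis, Result) with the phi kernels
// take_along_axis / take_along_axis_grad, matching the op_compat.yaml entry
// added above.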
KernelSignature TakeAlongAxisArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"take_along_axis", {"Input", "Index"}, {"Axis"}, {"Result"});
}
KernelSignature TakeAlongAxisGradArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("take_along_axis_grad",
{"Input", "Index", "Result@GRAD"},
{"Axis"},
{"Input@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(take_along_axis, phi::TakeAlongAxisArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(take_along_axis_grad,
phi::TakeAlongAxisGradArgumentMapping);