Unverified · Commit 1c526e1d, authored by Zeng Jinle, committed by GitHub

Fix some grad op desc makers (#16633)

* fix some grad op desc maker
test=develop

* fix grad op desc makers
test=develop
Parent: ea2a2f77
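
Summary of the change: several grad ops were registered with `paddle::framework::DefaultGradOpDescMaker<true>`, which forwards every input and output of the forward op (and the output gradients) to the grad op, forcing all of those buffers to stay alive for the backward pass. Each such op now gets a hand-written `SingleGradOpDescMaker` that declares only the variables its gradient kernel actually reads, so the rest can be dropped or marked no-need-buffer and their memory reclaimed. The before/after shape of the change, taken from the hinge_loss hunk below:

// Before: the default maker forwards all forward inputs/outputs to the grad op.
REGISTER_OPERATOR(hinge_loss, ops::HingeLossOp, ops::HingeLossOpMaker<float>,
                  paddle::framework::DefaultGradOpDescMaker<true>);

// After: a custom maker forwards only what hinge_loss_grad actually uses.
REGISTER_OPERATOR(hinge_loss, ops::HingeLossOp, ops::HingeLossOpMaker<float>,
                  ops::HingeLossGradOpDescMaker);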
@@ -617,6 +617,25 @@ void OpDesc::Flush() {

 static std::once_flag init_infer_shape_funcs;

+/**
+ * NOTE(paddle-dev): Very tricky code here. Maybe we should find a
+ * better way to register the compile-time infershape method gently.
+ *
+ * Normally, we can register a class derived from InferShapeBase, so that
+ * we can set the field `infer_shape_` inside OpInfo when registering the op.
+ *
+ * However, there is another way to set the field `infer_shape_` inside
+ * OpInfo. Usually, we overload the InferShape method of OperatorWithKernel.
+ * After running the method InitInferShapeFuncs below, `infer_shape_` is set
+ * to the InferShape method of OperatorWithKernel. That is to say, we borrow
+ * the run-time InferShape method of OperatorWithKernel to serve as the
+ * compile-time InferShape method.
+ *
+ * However, at compile time we may not know the inputs, outputs and attrs of
+ * the run-time OperatorWithKernel. So the following code creates a fake
+ * OperatorWithKernel object, which is why the info_ field of OperatorBase
+ * is null.
+ */
 static void InitInferShapeFuncs() {
   std::call_once(init_infer_shape_funcs, [] {
     auto &map = OpInfoMap::Instance();
@@ -628,11 +647,16 @@ static void InitInferShapeFuncs() {
       PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
                      op_type);
       auto &op_info = it->second;
-      auto op = static_cast<OperatorWithKernel *>(op_info.Creator()(
-          "", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
       if (op_info.infer_shape_) {  // infer_shape has been registered.
         continue;
       }
+
+      auto op = dynamic_cast<OperatorWithKernel *>(op_info.Creator()(
+          "", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
+      PADDLE_ENFORCE_NOT_NULL(
+          op, "InferShapeBase is not registered to Operator %s", op_type);
+
       op_info.infer_shape_ = [op](InferShapeContext *ctx) {
         op->InferShape(ctx);
       };
......
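
A minimal, self-contained sketch of the trick described in the NOTE above; all types here are stand-ins, not the real framework classes. A fake operator instance is constructed once with empty inputs/outputs/attrs, and a lambda capturing it becomes the compile-time `infer_shape_` callback:

#include <functional>
#include <iostream>

struct InferShapeContext {};  // stand-in for framework::InferShapeContext

struct OperatorWithKernel {   // stand-in; only the virtual dispatch matters
  virtual ~OperatorWithKernel() = default;
  virtual void InferShape(InferShapeContext *ctx) const {
    std::cout << "run-time InferShape reused at compile time\n";
  }
};

struct OpInfo {
  std::function<void(InferShapeContext *)> infer_shape_;
};

int main() {
  OpInfo info;
  // The "fake" op exists only so the lambda can dispatch to its virtual
  // InferShape later; it never runs a kernel, so its members stay empty.
  auto *op = new OperatorWithKernel();
  info.infer_shape_ = [op](InferShapeContext *ctx) { op->InferShape(ctx); };

  InferShapeContext ctx;
  info.infer_shape_(&ctx);  // the compile-time call path
  delete op;  // the real registry keeps its instance alive for the process
}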
@@ -8,9 +8,6 @@ conv_shift
 cos
 cos_sim
 dequantize
-elementwise_div
-elementwise_max
-elementwise_min
 elu
 fc
 flatten
@@ -28,8 +25,6 @@ gelu
 gru
 hard_shrink
 hierarchical_sigmoid
-hinge_loss
-huber_loss
 leaky_relu
 log
 logsigmoid
@@ -57,7 +52,6 @@ requantize
 reshape
 rnn_memory_helper
 round
-row_conv
 sequence_softmax
 sin
 softplus
......
@@ -74,5 +74,8 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
   virtual void Apply() = 0;
 };

+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(BatchSizeLikeNoNeedBufferVarsInference,
+                                      "Input");
+
 }  // namespace operators
 }  // namespace paddle
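
`DECLARE_NO_NEED_BUFFER_VARS_INFERENCE` declares that an op reads only the metadata (shape, LoD) of the listed variable, never its data, so memory-optimization passes may release or reuse the underlying buffer. Declaring `BatchSizeLikeNoNeedBufferVarsInference` once in this shared header lets the fill_constant/gaussian_random/uniform_random `*_batch_size_like` registrations below all reuse it for their "Input" variable, of which only the batch dimension is ever consumed.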
@@ -13,10 +13,47 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
+#include <memory>
+#include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"

+namespace paddle {
+namespace operators {
+
+class ElementwiseDivOpMaker : public ElementwiseOpMaker {
+ protected:
+  std::string GetName() const override { return "Div"; }
+  std::string GetEquation() const override { return "Out = X / Y"; }
+};
+
+class ElementwiseDivGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("elementwise_div_grad");
+    op->SetInput("Y", Input("Y"));
+    op->SetInput("Out", Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
-REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
+
+REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp,
+                  ops::ElementwiseDivOpMaker, ops::ElementwiseOpInferVarType,
+                  ops::ElementwiseDivGradOpDescMaker);
+
+REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad);

 REGISTER_OP_CPU_KERNEL(
     elementwise_div,
......
@@ -47,7 +47,7 @@ struct DivGradDX {
 template <typename T>
 struct DivGradDY {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return -dout * x / (y * y);
+    return -dout * out / y;
   }
 };

@@ -58,13 +58,15 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
     ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;

-    auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
     auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     int axis = ctx.Attr<int>("axis");
+
+    auto* x = dout;  // Fake x, not used
+
     ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
         ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
   }
......
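
Why the rewritten DivGradDY is equivalent: with $\mathrm{Out} = X / Y$,

\[
\frac{\partial\,\mathrm{Out}}{\partial X} = \frac{1}{Y}, \qquad
\frac{\partial\,\mathrm{Out}}{\partial Y} = -\frac{X}{Y^{2}} = -\frac{\mathrm{Out}}{Y},
\]

so the backward pass can be expressed with only Y, Out and Out@GRAD. That is exactly the input set ElementwiseDivGradOpDescMaker declares above, and it is why X may now be passed as a dummy pointer (`auto* x = dout`) that neither DivGradDX nor DivGradDY ever reads.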
@@ -13,9 +13,48 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
+#include <memory>
+#include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"

+namespace paddle {
+namespace operators {
+
+class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
+ protected:
+  std::string GetName() const override { return "Max"; }
+  std::string GetEquation() const override { return "Out = max(X, Y)"; }
+};
+
+class ElementwiseMaxGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("elementwise_max_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("Y", Input("Y"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
-REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)");
+
+REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp,
+                  ops::ElementwiseMaxOpMaker, ops::ElementwiseOpInferVarType,
+                  ops::ElementwiseMaxGradOpDescMaker);
+
+REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);

 REGISTER_OP_CPU_KERNEL(
     elementwise_max,
     ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -63,10 +63,10 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto* out = dout;  // Fake out, not used
     int axis = ctx.Attr<int>("axis");
     ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
         ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
......
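
elementwise_max (and elementwise_min below) never use Out in the backward pass. Assuming the standard functors here, MaxGradDx returns dout * (x > y) and MaxGradDy returns dout * (x <= y), i.e.

\[
dX = d\mathrm{Out}\cdot\mathbf{1}[X > Y], \qquad dY = d\mathrm{Out}\cdot\mathbf{1}[X \le Y],
\]

so Out can be dropped from the grad op's inputs and replaced with the dummy `out = dout` pointer; only its shape, which matches Out@GRAD's, could ever matter.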
@@ -13,9 +13,48 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
+#include <memory>
+#include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"

+namespace paddle {
+namespace operators {
+
+class ElementwiseMinOpMaker : public ElementwiseOpMaker {
+ protected:
+  std::string GetName() const override { return "Min"; }
+  std::string GetEquation() const override { return "Out = min(X, Y)"; }
+};
+
+class ElementwiseMinGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("elementwise_min_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("Y", Input("Y"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
-REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)");
+
+REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp,
+                  ops::ElementwiseMinOpMaker, ops::ElementwiseOpInferVarType,
+                  ops::ElementwiseMinGradOpDescMaker);
+
+REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);

 REGISTER_OP_CPU_KERNEL(
     elementwise_min,
     ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -62,10 +62,10 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto* out = dout;  // Fake out, not used
     int axis = ctx.Attr<int>("axis");
     ElemwiseGradCompute<DeviceContext, T, MinGradDx<T>, MinGradDy<T>>(
         ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
......
@@ -173,12 +173,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
   using Tensor = framework::Tensor;

   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    auto out_grad_name = framework::GradVarName("Out");
     PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+    PADDLE_ENFORCE(ctx->HasInput(out_grad_name),
                    "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims = ctx->GetInputDim(out_grad_name);
     auto y_dims = ctx->GetInputDim("Y");

     PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
@@ -187,8 +187,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
     if (ctx->HasOutput(x_grad_name)) {
-      ctx->ShareDim("X", /*->*/ x_grad_name);
-      ctx->ShareLoD("X", /*->*/ x_grad_name);
+      ctx->ShareDim(out_grad_name, /*->*/ x_grad_name);
+      ctx->ShareLoD(out_grad_name, /*->*/ x_grad_name);
     }
     if (ctx->HasOutput(y_grad_name)) {
       ctx->ShareDim("Y", /*->*/ y_grad_name);
......
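
This substitution is safe because, for these elementwise ops, Out is created with X's dimensions and LoD, so Out@GRAD carries the same shape information as X. With the custom grad desc makers above no longer forwarding X, the grad op's shape inference keys off Out@GRAD instead, and X drops out of the grad op's input set entirely.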
@@ -46,6 +46,7 @@ obtained from the `input` tensor.
 )DOC");
   }
 };
+
 }  // namespace operators
 }  // namespace paddle

@@ -53,7 +54,8 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(fill_constant_batch_size_like,
                   ops::FillConstantBatchSizeLikeOp,
                   paddle::framework::EmptyGradOpMaker,
-                  ops::FillConstantBatchSizeLikeOpMaker);
+                  ops::FillConstantBatchSizeLikeOpMaker,
+                  ops::BatchSizeLikeNoNeedBufferVarsInference);

 REGISTER_OP_CPU_KERNEL(
     fill_constant_batch_size_like,
     ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
......
@@ -36,6 +36,7 @@ class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput("X", "The input of fill-zeros-like op.");
     AddOutput("Out", "The variable will be filled up with zeros.");
+    ExtraMake();
     AddComment(R"DOC(
 FillZerosLike Operator.

@@ -44,13 +45,49 @@ The output will have the same size as the input.
 )DOC");
   }
+
+ protected:
+  virtual void ExtraMake() {}
+};
+
+class FillZerosLikeOp2 : public FillZerosLikeOp {
+ public:
+  using FillZerosLikeOp::FillZerosLikeOp;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
+        ctx.GetPlace());
+  }
 };
+
+class FillZerosLikeOp2Maker : public FillZerosLikeOpMaker {
+ protected:
+  void ExtraMake() override {
+    this->AddAttr<int>("dtype",
+                       "(int, default 5(FP32)) "
+                       "Output data type.")
+        .SetDefault(framework::proto::VarType::FP32);
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(FillZerosLikeOp2NoNeedBufferVarsInference,
+                                      "X");
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
                              ops::FillZerosLikeOpMaker);
+
+REGISTER_OPERATOR(fill_zeros_like2, ops::FillZerosLikeOp2,
+                  ops::FillZerosLikeOp2Maker,
+                  ops::FillZerosLikeOp2NoNeedBufferVarsInference,
+                  paddle::framework::EmptyGradOpMaker);
+
 REGISTER_OP_CPU_KERNEL(
     fill_zeros_like,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
@@ -58,3 +95,11 @@ REGISTER_OP_CPU_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
+
+REGISTER_OP_CPU_KERNEL(
+    fill_zeros_like2,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
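
fill_zeros_like2 exists alongside fill_zeros_like (whose interface is left untouched) because the backward pass needs a variant whose output dtype comes from the `dtype` attribute rather than from inspecting X when the kernel is selected, and whose X is declared no-need-buffer: only X's shape is consumed, so the forward tensor's memory may already have been released when the zero-filling op runs.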
@@ -26,3 +26,13 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::float16>,
     ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
+
+REGISTER_OP_CUDA_KERNEL(
+    fill_zeros_like2,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
+                             paddle::platform::float16>,
+    ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
@@ -65,17 +65,13 @@ by input arguments.
   }
 };

-DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
-    GaussianRandomBatchSizeLikeNoNeedBufferVarsInference, "Input");
-
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OPERATOR(
-    gaussian_random_batch_size_like,
-    paddle::operators::GaussianRandomBatchSizeLikeOp,
-    paddle::operators::GaussianRandomBatchSizeLikeOpMaker,
-    paddle::framework::EmptyGradOpMaker,
-    paddle::operators::GaussianRandomBatchSizeLikeNoNeedBufferVarsInference);
+REGISTER_OPERATOR(gaussian_random_batch_size_like,
+                  paddle::operators::GaussianRandomBatchSizeLikeOp,
+                  paddle::operators::GaussianRandomBatchSizeLikeOpMaker,
+                  paddle::framework::EmptyGradOpMaker,
+                  paddle::operators::BatchSizeLikeNoNeedBufferVarsInference);

 // Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/hinge_loss_op.h"
+#include <memory>
+#include <string>
+#include <vector>

 namespace paddle {
 namespace operators {
@@ -97,12 +100,29 @@ class HingeLossGradOp : public framework::OperatorWithKernel {
   }
 };

+class HingeLossGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("hinge_loss_grad");
+    op->SetInput("Logits", Input("Logits"));
+    op->SetInput("Labels", Input("Labels"));
+    op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
+    op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(hinge_loss, ops::HingeLossOp, ops::HingeLossOpMaker<float>,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::HingeLossGradOpDescMaker);
 REGISTER_OPERATOR(hinge_loss_grad, ops::HingeLossGradOp);

 REGISTER_OP_CPU_KERNEL(
     hinge_loss,
......
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/huber_loss_op.h"
+#include <memory>
+#include <string>
+#include <vector>

 namespace paddle {
 namespace operators {
@@ -90,29 +93,36 @@ class HuberLossGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Residual"),
-                   "Input(Residual) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null.");

-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
     auto residual_dims = ctx->GetInputDim("Residual");
-    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-
-    PADDLE_ENFORCE_EQ(residual_dims, x_dims);
-    PADDLE_ENFORCE_EQ(out_grad_dims, x_dims);

     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
     if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
+      ctx->SetOutputDim(x_grad_name, residual_dims);
     }
     if (ctx->HasOutput(y_grad_name)) {
-      ctx->SetOutputDim(y_grad_name, y_dims);
+      ctx->SetOutputDim(y_grad_name, residual_dims);
     }
   }
 };

+class HuberLossGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("huber_loss_grad");
+    op->SetInput("Residual", Output("Residual"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
@@ -121,7 +131,7 @@ class HuberLossGradOp : public framework::OperatorWithKernel {

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::HuberLossGradOpDescMaker);
 REGISTER_OPERATOR(huber_loss_grad, ops::HuberLossGradOp);

 REGISTER_OP_CPU_KERNEL(
     huber_loss, ops::HuberLossKernel<paddle::platform::CPUDeviceContext, float>,
......
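
The new maker feeds the forward output Residual (plus Out@GRAD) into huber_loss_grad instead of X and Y, so the InferShape above may no longer touch X or Y. Residual is elementwise over X and Y (the old code enforced residual_dims == x_dims), so its dimensions serve for both gradient outputs.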
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/row_conv_op.h"
+#include <memory>
+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/eigen.h"

 namespace paddle {
@@ -54,7 +58,6 @@ class RowConvGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Filter"),
                    "Input(Filter) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
@@ -62,8 +65,8 @@ class RowConvGradOp : public framework::OperatorWithKernel {

     auto x_grad_name = framework::GradVarName("X");
     if (ctx->HasOutput(x_grad_name)) {
-      auto x_dims = ctx->GetInputDim("X");
-      ctx->SetOutputDim(x_grad_name, x_dims);
+      auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+      ctx->SetOutputDim(x_grad_name, dout_dims);
     }

     auto filter_grad_name = framework::GradVarName("Filter");
@@ -259,12 +262,31 @@ class RowConvGradKernel<platform::CPUDeviceContext, T>
     }
   }
 };

+class RowConvGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("row_conv_grad");
+    op->SetAttrMap(Attrs());
+    op->SetInput("X", Input("X"));
+    op->SetInput("Filter", Input("Filter"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(row_conv, ops::RowConvOp, ops::RowConvOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::RowConvGradOpDescMaker);
 REGISTER_OPERATOR(row_conv_grad, ops::RowConvGradOp);

 REGISTER_OP_CPU_KERNEL(
     row_conv, ops::RowConvKernel<paddle::platform::CPUDeviceContext, float>);
......
@@ -64,8 +64,9 @@ with random values sampled from a uniform distribution.
 }  // namespace operators
 }  // namespace paddle

-REGISTER_OP_WITHOUT_GRADIENT(
-    uniform_random_batch_size_like,
-    paddle::operators::UniformRandomBatchSizeLikeOp,
-    paddle::operators::UniformRandomBatchSizeLikeOpMaker);
+REGISTER_OPERATOR(uniform_random_batch_size_like,
+                  paddle::operators::UniformRandomBatchSizeLikeOp,
+                  paddle::operators::UniformRandomBatchSizeLikeOpMaker,
+                  paddle::framework::EmptyGradOpMaker,
+                  paddle::operators::BatchSizeLikeNoNeedBufferVarsInference);
 // Kernels are registered in uniform_random_op.cc and uniform_random_op.cu
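
`REGISTER_OP_WITHOUT_GRADIENT` is simply registration without a grad op maker; switching to `REGISTER_OPERATOR` with `paddle::framework::EmptyGradOpMaker` keeps the op gradient-free while allowing the extra `BatchSizeLikeNoNeedBufferVarsInference` argument to be attached.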
@@ -231,9 +231,16 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
     for idx, op_desc in enumerate(op_descs):
         for arg in op_desc.input_arg_names():
             if core.grad_var_suffix() in arg and arg in no_grad_set:
-                to_insert.append((_create_op_desc_("fill_zeros_like", {
-                    "X": [_strip_grad_suffix_(arg)]
-                }, {"Out": [arg]}, {}), idx))
+                x_in = _strip_grad_suffix_(arg)
+                x_in_var_desc = op_desc.block().find_var_recursive(
+                    cpt.to_bytes(x_in))
+                assert x_in_var_desc is not None, "Variable {} not found".format(
+                    x_in)
+                dtype = x_in_var_desc.dtype()
+                to_insert.append(
+                    (_create_op_desc_("fill_zeros_like2", {"X": [x_in]},
+                                      {"Out": [arg]}, {"dtype": dtype}), idx))

     list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
......
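
The gradient being synthesized here belongs to a variable in no_grad_set, so its tensor may never have been written. The old code let fill_zeros_like derive the output dtype from X's runtime tensor, which requires X's data to still be alive; the new code instead resolves the forward variable's VarDesc at graph-construction time, reads its dtype there, and bakes it into fill_zeros_like2 as an attribute, so at run time the op needs only X's shape.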
@@ -23,6 +23,8 @@ from test_elementwise_sub_op import *
 from test_concat_op import *
 from test_gather_op import *
 from test_gaussian_random_batch_size_like_op import *
+from test_uniform_random_batch_size_like_op import *
+from test_fill_constant_batch_size_like_op import *
 from test_lod_reset_op import *
 from test_scatter_op import *
 from test_mean_op import *
@@ -40,6 +42,7 @@ from test_sequence_unpad_op import *
 from test_sequence_scatter_op import *
 from test_sequence_slice_op import *
 from test_pad2d_op import *
+from test_fill_zeros_like2_op import *

 if __name__ == '__main__':
     unittest.main()
New file test_fill_zeros_like2_op.py:

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from op_test import OpTest


class TestFillZerosLike2Op(OpTest):
    def setUp(self):
        self.op_type = "fill_zeros_like2"
        self.dtype = np.float32
        self.init_dtype()
        self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
        self.outputs = {'Out': np.zeros_like(self.inputs["X"])}
        self.attrs = {'dtype': convert_np_dtype_to_dtype_(self.dtype)}

    def init_dtype(self):
        pass

    def test_check_output(self):
        self.check_output()


class TestFillZerosLike2OpFp16(TestFillZerosLike2Op):
    def init_dtype(self):
        self.dtype = np.float16


class TestFillZerosLike2OpFp64(TestFillZerosLike2Op):
    def init_dtype(self):
        self.dtype = np.float64


if __name__ == "__main__":
    unittest.main()