Unverified commit 1c526e1d authored by Zeng Jinle, committed by GitHub

Fix some grad op desc makers (#16633)

* fix some grad op desc maker
test=develop

* fix grad op desc makers
test=develop
Parent: ea2a2f77
......@@ -617,6 +617,25 @@ void OpDesc::Flush() {
static std::once_flag init_infer_shape_funcs;
/**
* NOTE(paddle-dev): Very tricky code here. Maybe we should find a
* better way to register the compile-time infershape method gently.
*
* Normally, we can register a class derived from InferShapeBase, so that
* the `infer_shape_` field of OpInfo is set when the op is registered.
*
* However, there is another way to set the `infer_shape_` field of
* OpInfo. Usually, we overload the InferShape method of
* OperatorWithKernel. After the following method InitInferShapeFuncs
* runs, `infer_shape_` is set to the InferShape method of
* OperatorWithKernel. That is to say, we borrow the run-time InferShape
* method of OperatorWithKernel to serve as the compile-time InferShape
* method.
*
* However, at compile time we may not know the inputs, outputs and attrs
* of the run-time OperatorWithKernel. So the following code creates a
* fake OperatorWithKernel object, which is why the `info_` field of
* OperatorBase is null.
*/
static void InitInferShapeFuncs() {
std::call_once(init_infer_shape_funcs, [] {
auto &map = OpInfoMap::Instance();
......@@ -628,11 +647,16 @@ static void InitInferShapeFuncs() {
PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
op_type);
auto &op_info = it->second;
auto op = static_cast<OperatorWithKernel *>(op_info.Creator()(
"", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
if (op_info.infer_shape_) { // infer_shape has been registered.
continue;
}
auto op = dynamic_cast<OperatorWithKernel *>(op_info.Creator()(
"", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
PADDLE_ENFORCE_NOT_NULL(
op, "InferShapeBase is not registered to Operator %s", op_type);
op_info.infer_shape_ = [op](InferShapeContext *ctx) {
op->InferShape(ctx);
};
......
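The key behavioral change above: ops that already registered a compile-time InferShape are now skipped, and the downcast is checked. A minimal, self-contained sketch (toy types, not Paddle's real hierarchy) of why dynamic_cast replaces static_cast here:

```cpp
#include <cassert>

// Toy stand-ins, for illustration only.
struct OperatorBase { virtual ~OperatorBase() = default; };
struct OperatorWithKernel : OperatorBase {};
struct RuntimeOnlyOp : OperatorBase {};  // hypothetical op without a kernel

int main() {
  OperatorBase *op = new RuntimeOnlyOp;
  // dynamic_cast checks the dynamic type and yields nullptr here, which
  // PADDLE_ENFORCE_NOT_NULL can turn into a clear registration error.
  assert(dynamic_cast<OperatorWithKernel *>(op) == nullptr);
  // static_cast would have "succeeded" unconditionally; calling InferShape
  // through the resulting pointer would be undefined behavior.
  delete op;
}
```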
......@@ -8,9 +8,6 @@ conv_shift
cos
cos_sim
dequantize
elementwise_div
elementwise_max
elementwise_min
elu
fc
flatten
......@@ -28,8 +25,6 @@ gelu
gru
hard_shrink
hierarchical_sigmoid
hinge_loss
huber_loss
leaky_relu
log
logsigmoid
......@@ -57,7 +52,6 @@ requantize
reshape
rnn_memory_helper
round
row_conv
sequence_softmax
sin
softplus
......
......@@ -74,5 +74,8 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
virtual void Apply() = 0;
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(BatchSizeLikeNoNeedBufferVarsInference,
"Input");
} // namespace operators
} // namespace paddle
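The BatchSizeLikeNoNeedBufferVarsInference declared above is shared by the *_batch_size_like operators re-registered later in this commit. A hedged sketch of the intended usage (the operator name and classes below are hypothetical):

```cpp
// Hypothetical registration reusing the shared declaration: these ops read
// only Input's shape (to copy its batch dimension), never Input's data, so
// the framework may discard Input's memory buffer.
REGISTER_OPERATOR(my_batch_size_like,  // hypothetical op name
                  ops::MyBatchSizeLikeOp,
                  ops::MyBatchSizeLikeOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::BatchSizeLikeNoNeedBufferVarsInference);
```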
......@@ -13,10 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Div"; }
std::string GetEquation() const override { return "Out = X / Y"; }
};
class ElementwiseDivGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("elementwise_div_grad");
op->SetInput("Y", Input("Y"));
op->SetInput("Out", Output("Out"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp,
ops::ElementwiseDivOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseDivGradOpDescMaker);
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_div,
......
......@@ -47,7 +47,7 @@ struct DivGradDX {
template <typename T>
struct DivGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return -dout * x / (y * y);
return -dout * out / y;
}
};
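The algebra behind this one-line change: since Out = X / Y, the X in the Y-gradient can be rewritten in terms of Out, so neither gradient needs X (DivGradDX is dOut / Y). That is exactly why the new desc maker above forwards only Y, Out, and Out@GRAD:

$$\frac{\partial \mathrm{Out}}{\partial Y} = -\frac{X}{Y^{2}} = -\frac{1}{Y}\cdot\frac{X}{Y} = -\frac{\mathrm{Out}}{Y} \quad\Longrightarrow\quad dY = -d\mathrm{Out}\cdot\frac{\mathrm{Out}}{Y}, \qquad dX = \frac{d\mathrm{Out}}{Y}.$$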
......@@ -58,13 +58,15 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto* x = dout; // Fake x, not used
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
}
......
......@@ -13,9 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Max"; }
std::string GetEquation() const override { return "Out = max(X, Y)"; }
};
class ElementwiseMaxGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("elementwise_max_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)");
REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp,
ops::ElementwiseMaxOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseMaxGradOpDescMaker);
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -63,10 +63,10 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* out = dout; // Fake out, not used
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
......
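For max the trade-off is mirrored: the sub-gradients route dOut by comparing X and Y directly, so Out is the tensor that becomes unnecessary (elementwise_min below is symmetric). Assuming the functors follow the usual convention (MaxGradDx: dout * (x > y), MaxGradDy: dout * (x <= y)):

$$dX = d\mathrm{Out}\cdot\mathbf{1}[X > Y], \qquad dY = d\mathrm{Out}\cdot\mathbf{1}[X \le Y].$$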
......@@ -13,9 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMinOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Min"; }
std::string GetEquation() const override { return "Out = min(X, Y)"; }
};
class ElementwiseMinGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("elementwise_min_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)");
REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp,
ops::ElementwiseMinOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseMinGradOpDescMaker);
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -62,10 +62,10 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* out = dout; // Fake out, not used
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MinGradDx<T>, MinGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
......
......@@ -173,12 +173,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
PADDLE_ENFORCE(ctx->HasInput(out_grad_name),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto x_dims = ctx->GetInputDim(out_grad_name);
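// Out is the broadcast result, so Out@GRAD shares X's shape; reading dims
// from Out@GRAD removes the grad op's dependence on X as a real input.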
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
......@@ -187,8 +187,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim("X", /*->*/ x_grad_name);
ctx->ShareLoD("X", /*->*/ x_grad_name);
ctx->ShareDim(out_grad_name, /*->*/ x_grad_name);
ctx->ShareLoD(out_grad_name, /*->*/ x_grad_name);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name);
......
......@@ -46,6 +46,7 @@ obtained from the `input` tensor.
)DOC");
}
};
} // namespace operators
} // namespace paddle
......@@ -53,7 +54,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOp,
paddle::framework::EmptyGradOpMaker,
ops::FillConstantBatchSizeLikeOpMaker);
ops::FillConstantBatchSizeLikeOpMaker,
ops::BatchSizeLikeNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
......
......@@ -36,6 +36,7 @@ class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Out", "The variable will be filled up with zeros.");
ExtraMake();
AddComment(R"DOC(
FillZerosLike Operator.
......@@ -44,13 +45,49 @@ The output will have the same size as the input.
)DOC");
}
protected:
virtual void ExtraMake() {}
};
class FillZerosLikeOp2 : public FillZerosLikeOp {
public:
using FillZerosLikeOp::FillZerosLikeOp;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class FillZerosLikeOp2Maker : public FillZerosLikeOpMaker {
protected:
void ExtraMake() override {
this->AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::proto::VarType::FP32);
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(FillZerosLikeOp2NoNeedBufferVarsInference,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
ops::FillZerosLikeOpMaker);
REGISTER_OPERATOR(fill_zeros_like2, ops::FillZerosLikeOp2,
ops::FillZerosLikeOp2Maker,
ops::FillZerosLikeOp2NoNeedBufferVarsInference,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
......@@ -58,3 +95,11 @@ REGISTER_OP_CPU_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like2,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
......@@ -26,3 +26,13 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
fill_zeros_like2,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
......@@ -65,17 +65,13 @@ by input arguments.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
GaussianRandomBatchSizeLikeNoNeedBufferVarsInference, "Input");
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(
gaussian_random_batch_size_like,
REGISTER_OPERATOR(gaussian_random_batch_size_like,
paddle::operators::GaussianRandomBatchSizeLikeOp,
paddle::operators::GaussianRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker,
paddle::operators::GaussianRandomBatchSizeLikeNoNeedBufferVarsInference);
paddle::operators::BatchSizeLikeNoNeedBufferVarsInference);
// Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hinge_loss_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -97,12 +100,29 @@ class HingeLossGradOp : public framework::OperatorWithKernel {
}
};
class HingeLossGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("hinge_loss_grad");
op->SetInput("Logits", Input("Logits"));
op->SetInput("Labels", Input("Labels"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(hinge_loss, ops::HingeLossOp, ops::HingeLossOpMaker<float>,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::HingeLossGradOpDescMaker);
REGISTER_OPERATOR(hinge_loss_grad, ops::HingeLossGradOp);
REGISTER_OP_CPU_KERNEL(
hinge_loss,
......
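A quick check that the inputs declared by HingeLossGradOpDescMaker are sufficient, assuming the forward pass computes the usual form L = max(0, 1 − (2·Labels − 1)·Logits):

$$\frac{\partial L}{\partial \mathrm{Logits}} = -(2\,\mathrm{Labels}-1)\cdot\mathbf{1}\big[1-(2\,\mathrm{Labels}-1)\,\mathrm{Logits} > 0\big]\cdot d\mathrm{Loss},$$

which depends on Logits, Labels, and Loss@GRAD but never on Loss itself, so the forward output no longer has to be kept alive (DefaultGradOpDescMaker<true> forwarded every forward input and output).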
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/huber_loss_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -90,29 +93,36 @@ class HuberLossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Residual"),
"Input(Residual) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto residual_dims = ctx->GetInputDim("Residual");
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(residual_dims, x_dims);
PADDLE_ENFORCE_EQ(out_grad_dims, x_dims);
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
ctx->SetOutputDim(x_grad_name, residual_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
ctx->SetOutputDim(y_grad_name, residual_dims);
}
}
};
class HuberLossGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("huber_loss_grad");
op->SetInput("Residual", Output("Residual"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
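Why Residual alone suffices, assuming the forward kernel stores r = Y − X and δ is the delta attribute:

$$L(r)=\begin{cases}\tfrac{1}{2}r^{2}, & |r|\le\delta,\\ \delta\left(|r|-\tfrac{1}{2}\delta\right), & |r|>\delta,\end{cases} \qquad \frac{dL}{dr}=\begin{cases}r, & |r|\le\delta,\\ \delta\,\mathrm{sgn}(r), & |r|>\delta,\end{cases}$$

so dX = −(dL/dr)·dOut and dY = (dL/dr)·dOut need only Residual and Out@GRAD. Their shapes likewise follow residual_dims, which matches the InferShape change above.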
......@@ -121,7 +131,7 @@ class HuberLossGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OPERATOR(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::HuberLossGradOpDescMaker);
REGISTER_OPERATOR(huber_loss_grad, ops::HuberLossGradOp);
REGISTER_OP_CPU_KERNEL(
huber_loss, ops::HuberLossKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/row_conv_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
namespace paddle {
......@@ -54,7 +58,6 @@ class RowConvGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
......@@ -62,8 +65,8 @@ class RowConvGradOp : public framework::OperatorWithKernel {
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(x_grad_name, x_dims);
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(x_grad_name, dout_dims);
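// row_conv preserves the input shape, so Out@GRAD carries the same dims as X.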
}
auto filter_grad_name = framework::GradVarName("Filter");
......@@ -259,12 +262,31 @@ class RowConvGradKernel<platform::CPUDeviceContext, T>
}
}
};
class RowConvGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("row_conv_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput("Filter", Input("Filter"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(row_conv, ops::RowConvOp, ops::RowConvOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::RowConvGradOpDescMaker);
REGISTER_OPERATOR(row_conv_grad, ops::RowConvGradOp);
REGISTER_OP_CPU_KERNEL(
row_conv, ops::RowConvKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -64,8 +64,9 @@ with random values sampled from a uniform distribution.
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
uniform_random_batch_size_like,
REGISTER_OPERATOR(uniform_random_batch_size_like,
paddle::operators::UniformRandomBatchSizeLikeOp,
paddle::operators::UniformRandomBatchSizeLikeOpMaker);
paddle::operators::UniformRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker,
paddle::operators::BatchSizeLikeNoNeedBufferVarsInference);
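// REGISTER_OP_WITHOUT_GRADIENT takes a fixed argument list; spelling out
// REGISTER_OPERATOR with EmptyGradOpMaker keeps the op gradient-free while
// letting the shared no-need-buffer declaration be attached.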
// Kernels are registered in uniform_random_op.cc and uniform_random_op.cu
......@@ -231,9 +231,16 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
for idx, op_desc in enumerate(op_descs):
for arg in op_desc.input_arg_names():
if core.grad_var_suffix() in arg and arg in no_grad_set:
to_insert.append((_create_op_desc_("fill_zeros_like", {
"X": [_strip_grad_suffix_(arg)]
}, {"Out": [arg]}, {}), idx))
x_in = _strip_grad_suffix_(arg)
x_in_var_desc = op_desc.block().find_var_recursive(
cpt.to_bytes(x_in))
assert x_in_var_desc is not None, "Variable {} not found".format(
x_in)
dtype = x_in_var_desc.dtype()
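# fill_zeros_like2 takes dtype as an attribute, so the zero gradient keeps
# the forward variable's dtype without ever reading X's data buffer.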
to_insert.append(
(_create_op_desc_("fill_zeros_like2", {"X": [x_in]},
{"Out": [arg]}, {"dtype": dtype}), idx))
list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
......
......@@ -23,6 +23,8 @@ from test_elementwise_sub_op import *
from test_concat_op import *
from test_gather_op import *
from test_gaussian_random_batch_size_like_op import *
from test_uniform_random_batch_size_like_op import *
from test_fill_constant_batch_size_like_op import *
from test_lod_reset_op import *
from test_scatter_op import *
from test_mean_op import *
......@@ -40,6 +42,7 @@ from test_sequence_unpad_op import *
from test_sequence_scatter_op import *
from test_sequence_slice_op import *
from test_pad2d_op import *
from test_fill_zeros_like2_op import *
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from op_test import OpTest
class TestFillZerosLike2Op(OpTest):
def setUp(self):
self.op_type = "fill_zeros_like2"
self.dtype = np.float32
self.init_dtype()
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
self.outputs = {'Out': np.zeros_like(self.inputs["X"])}
self.attrs = {'dtype': convert_np_dtype_to_dtype_(self.dtype)}
def init_dtype(self):
pass
def test_check_output(self):
self.check_output()
class TestFillZerosLike2OpFp16(TestFillZerosLike2Op):
def init_dtype(self):
self.dtype = np.float16
class TestFillZerosLike2OpFp64(TestFillZerosLike2Op):
def init_dtype(self):
self.dtype = np.float64
if __name__ == "__main__":
unittest.main()