Unverified commit d0ef6825, authored by Zeng Jinle, committed by GitHub

Merge pull request #16274 from sneaxiy/fix_grad_maker

Remove unused variables in op grad maker
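The pattern applied across the files below: each op's `DefaultGradOpDescMaker<true>`, which wires every forward input and output (plus their gradients) into the grad op, is replaced by a hand-written `SingleGradOpDescMaker` that declares only the variables the backward kernel actually reads. Combined with the `fast_eager_deletion_mode` flag flipped in the first hunk, this lets the garbage collector free the unneeded forward tensors without waiting for the backward pass. A minimal sketch of the pattern, using a hypothetical operator `my_op` whose backward pass needs only `Out` and `Out@GRAD` (illustrative only; `my_op`, `X`, and `Out` are placeholder names, not part of this PR):

// Sketch, assuming the fluid framework headers of this era; not a real op.
#include <memory>
#include "paddle/fluid/framework/op_registry.h"

class MyOpGradDescMaker : public paddle::framework::SingleGradOpDescMaker {
 public:
  using paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<paddle::framework::OpDesc> Apply() const override {
    std::unique_ptr<paddle::framework::OpDesc> op(
        new paddle::framework::OpDesc());
    op->SetType("my_op_grad");
    // Declare only what the backward kernel reads. Skipping the forward
    // input "X" is what allows eager deletion to release it early.
    op->SetInput("Out", Output("Out"));
    op->SetInput(paddle::framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(paddle::framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};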
@@ -34,7 +34,7 @@ DEFINE_double(
     "Memory size threshold (GB) when the garbage collector clear tensors."
     "Disabled when this value is less than 0");
-DEFINE_bool(fast_eager_deletion_mode, false,
+DEFINE_bool(fast_eager_deletion_mode, true,
             "Fast eager deletion mode. If enabled, memory would release "
             "immediately without waiting GPU kernel ends.");
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/conv_transpose_op.h"
+#include <memory>
 #include <string>
 #include <vector>
@@ -344,6 +345,28 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
                                  ctx.GetPlace(), layout_, library_);
 }
 
+class ConvTransposeGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType(ForwardOp().Type() + "_grad");
+    op->SetInput("Input", Input("Input"));
+    op->SetInput("Filter", Input("Filter"));
+    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
+    if (ForwardOp().Inputs().count("Bias") > 0) {
+      op->SetInput("Bias", Input("Bias"));
+      op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
+    }
+    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
@@ -352,7 +375,7 @@ namespace ops = paddle::operators;
 // conv2d_transpose
 REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
                   ops::Conv2DTransposeOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::ConvTransposeGradOpDescMaker);
 REGISTER_OPERATOR(conv2d_transpose_grad, ops::ConvTransposeOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
@@ -368,7 +391,7 @@ REGISTER_OP_CPU_KERNEL(
 // conv3d_transpose
 REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
                   ops::Conv3DTransposeOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::ConvTransposeGradOpDescMaker);
 REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
@@ -384,7 +407,7 @@ REGISTER_OP_CPU_KERNEL(
 // depthwise conv2d_transpose
 REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp,
                   ops::Conv2DTransposeOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::ConvTransposeGradOpDescMaker);
 REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/dropout_op.h"
+#include <memory>
 #include <string>
 
 namespace paddle {
@@ -106,21 +107,31 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
                       "GradOp is only callable when is_test is false");
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) must not be null.");
-    auto x_dims = ctx->GetInputDim("X");
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    PADDLE_ENFORCE_EQ(x_dims, out_dims,
-                      "Dimensions of Input(X) and Out@Grad must be the same.");
-    auto mask_dims = ctx->GetInputDim("Mask");
-    PADDLE_ENFORCE_EQ(x_dims, mask_dims,
-                      "Dimensions of Input(X) and Mask must be the same.");
-
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
+
+    ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
+    ctx->ShareLoD(framework::GradVarName("Out"),
+                  /*->*/ framework::GradVarName("X"));
+  }
+};
+
+class DropoutGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("dropout_grad");
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetInput("Mask", Output("Mask"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
   }
 };
@@ -129,7 +140,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::DropoutGradOpDescMaker);
 REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
 REGISTER_OP_CPU_KERNEL(
     dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/layer_norm_op.h"
+#include <memory>
 
 namespace paddle {
 namespace operators {
@@ -133,7 +134,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
     }
     if (ctx->HasOutput(framework::GradVarName("Bias"))) {
       ctx->SetOutputDim(framework::GradVarName("Bias"),
-                        ctx->GetInputDim("Bias"));
+                        ctx->GetInputDim("Scale"));
     }
   }
@@ -157,12 +158,39 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
   }
 };
 
+class LayerNormGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("layer_norm_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("Mean", Output("Mean"));
+    op->SetInput("Variance", Output("Variance"));
+    if (ForwardOp().Inputs().count("Scale") > 0) {
+      op->SetInput("Scale", Input("Scale"));
+      op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
+    }
+    if (ForwardOp().Inputs().count("Bias") > 0) {
+      op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
+    }
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::LayerNormGradOpDescMaker);
 REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp);
 REGISTER_OP_CPU_KERNEL(
     layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -245,11 +245,9 @@ class LayerNormGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     const float epsilon = ctx.Attr<float>("epsilon");
     auto x = *ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
     auto* mean = ctx.Input<Tensor>("Mean");
     auto* var = ctx.Input<Tensor>("Variance");
     auto* scale = ctx.Input<Tensor>("Scale");
-    auto* bias = ctx.Input<Tensor>("Bias");
     auto d_y = *ctx.Input<Tensor>(framework::GradVarName("Y"));
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
@@ -275,18 +273,13 @@ class LayerNormGradKernel : public framework::OpKernel<T> {
     x.Resize(matrix_shape);
     temp.mutable_data<T>(matrix_shape, ctx.GetPlace());
-    if (!(bias && scale)) {
-      temp_norm.ShareDataWith(*y);
-      temp_norm.Resize(matrix_shape);
-    } else {
-      temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
-      // get x_norm
-      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
-      ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
-          ctx, &temp_norm, var, /*axis*/ 0,
-          DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
-    }
+    temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
+    // get x_norm
+    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
+        ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
+    ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
+        ctx, &temp_norm, var, /*axis*/ 0,
+        DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
   }
 
   if (d_bias) {
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
+#include <memory>
 
 namespace paddle {
 namespace operators {
@@ -187,7 +188,6 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
     grad_op->SetType("softmax_with_cross_entropy_grad");
     grad_op->SetInput("Label", Input("Label"));
     grad_op->SetInput("Softmax", Output("Softmax"));
-    grad_op->SetInput("Loss", Output("Loss"));
     grad_op->SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax"));
     grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
     grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
......