Unverified commit f7b80ada, authored by Xiaoxu Chen and committed by GitHub

[prim add instance_norm custom vjp] (#52935)

Parent 79a01d6c
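For reference, the composite rule added below computes the instance_norm VJP directly from the forward statistics: SavedMean holds the per-(N, C) spatial mean, and SavedVariance, as consumed by the rule, holds the inverse standard deviation. A minimal NumPy sketch of the same formulas follows; it is not Paddle code, and the helper names (`instance_norm_ref`, `instance_norm_vjp`) and the NCHW layout are assumptions made for illustration:

```python
# Reference sketch (not Paddle code) of the instance_norm forward statistics
# and the VJP formulas used by the composite rule in this change.
import numpy as np


def instance_norm_ref(x, scale, bias, eps):
    # Per-(n, c) mean and inverse std over the spatial dims, matching the
    # op's SavedMean / SavedVariance (the latter stores 1/std).
    mean = x.mean(axis=(2, 3), keepdims=True)
    std_inv = 1.0 / np.sqrt(x.var(axis=(2, 3), keepdims=True) + eps)
    x_hat = (x - mean) * std_inv
    y = scale.reshape(1, -1, 1, 1) * x_hat + bias.reshape(1, -1, 1, 1)
    return y, mean, std_inv


def instance_norm_vjp(x, scale, mean, std_inv, y_grad):
    x_hat = (x - mean) * std_inv
    s = scale.reshape(1, -1, 1, 1)
    # x_grad = scale * std_inv * (y_grad - y_grad.mean((h, w))
    #                             - x_hat * (y_grad * x_hat).mean((h, w)))
    x_grad = (s * std_inv) * (
        y_grad
        - y_grad.mean(axis=(2, 3), keepdims=True)
        - x_hat * (y_grad * x_hat).mean(axis=(2, 3), keepdims=True)
    )
    # scale_grad = (y_grad * x_hat).sum((n, h, w)); bias_grad = y_grad.sum((n, h, w))
    scale_grad = (y_grad * x_hat).sum(axis=(0, 2, 3))
    bias_grad = y_grad.sum(axis=(0, 2, 3))
    return x_grad, scale_grad, bias_grad
```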
......@@ -21,6 +21,9 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/ternary.h"
......@@ -140,6 +143,48 @@ phi::KernelKey InstanceNormDoubleGradOp::GetExpectedKernelType(
                        ctx.GetPlace());
}
class InstanceNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    // inputs and outputs of instance_norm
    paddle::Tensor x = this->GetSingleForwardInput("X");
    paddle::Tensor scale = this->GetSingleForwardInput("Scale");
    paddle::Tensor saved_mean = this->GetSingleForwardOutput("SavedMean");
    paddle::Tensor saved_variance =
        this->GetSingleForwardOutput("SavedVariance");
    paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
    paddle::Tensor x_grad = this->GetSingleInputGrad("X");
    paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
    paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");

    auto x_grad_ptr = this->GetOutputPtr(&x_grad);
    std::string x_grad_name = this->GetOutputName(x_grad);
    auto scale_grad_ptr = this->GetOutputPtr(&scale_grad);
    std::string scale_grad_name = this->GetOutputName(scale_grad);
    auto bias_grad_ptr = this->GetOutputPtr(&bias_grad);
    std::string bias_grad_name = this->GetOutputName(bias_grad);

    auto epsilon = this->Attr<float>("epsilon");

    VLOG(3) << "Running instance_norm composite func";
    prim::instance_norm_grad<prim::DescTensor>(x,
                                               scale,
                                               saved_mean,
                                               saved_variance,
                                               y_grad,
                                               epsilon,
                                               x_grad_ptr,
                                               scale_grad_ptr,
                                               bias_grad_ptr);
    this->RecoverOutputName(x_grad, x_grad_name);
    this->RecoverOutputName(scale_grad, scale_grad_name);
    this->RecoverOutputName(bias_grad, bias_grad_name);
  }
};
DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,
                           {"DY", "DDY"});
......@@ -163,7 +208,8 @@ REGISTER_OPERATOR(instance_norm,
                  ops::InstanceNormOpInferVarType,
                  ops::InstanceNormGradMaker<paddle::framework::OpDesc>,
                  ops::InstanceNormGradMaker<paddle::imperative::OpBase>,
                  InstanceNormInferShapeFunctor);
                  InstanceNormInferShapeFunctor,
                  ops::InstanceNormCompositeGradOpMaker);
REGISTER_OPERATOR(instance_norm_grad,
                  ops::InstanceNormGradOp,
                  ops::InstanceNormDoubleGradMaker<paddle::framework::OpDesc>,
......
......@@ -1493,6 +1493,58 @@ void batch_norm_grad(const Tensor& x,
}
}
template <typename T>
void instance_norm_grad(const Tensor& x,
                        const paddle::optional<Tensor>& scale,
                        const Tensor& saved_mean,
                        const Tensor& saved_variance,
                        const Tensor& y_grad,
                        float epsilon,
                        Tensor* x_grad,
                        Tensor* scale_grad,
                        Tensor* bias_grad) {
  const int n = x.dims()[0];
  const int c = x.dims()[1];
  const int h = x.dims()[2];
  const int w = x.dims()[3];

  Tensor x_hat;
  Tensor std_inv;
  if (scale_grad || x_grad) {
    auto mean = reshape<T>(saved_mean, IntArray({n, c, 1, 1}))
                    .tile(IntArray({1, 1, h, w}));
    std_inv = reshape<T>(saved_variance, IntArray({n, c, 1, 1}))
                  .tile(IntArray({1, 1, h, w}));
    x_hat = (x - mean) * std_inv;
  }
  // x_grad = scale * std_inv * (y_grad - y_grad.mean((h, w)) -
  //          x_hat * (y_grad * x_hat).mean((h, w)))
  if (x_grad) {
    auto scale_t =
        reshape<T>(scale.get_ptr() ? scale.get()
                                   : full<T>(IntArray({c}), 1., x.dtype()),
                   IntArray({1, c, 1, 1}))
            .tile(IntArray({n, 1, h, w}));
    set_output<T>(
        (scale_t * std_inv) *
            (y_grad -
             y_grad.sum(IntArray({2, 3}), y_grad.dtype(), true) / (h * w) -
             (x_hat *
              ((y_grad * x_hat).sum(IntArray({2, 3}), y_grad.dtype(), true) /
               (h * w)))),
        x_grad);
  }
  // scale_grad = (y_grad * x_hat).sum((n, h, w))
  if (scale_grad) {
    set_output<T>((y_grad * x_hat).sum(IntArray({0, 2, 3})), scale_grad);
  }
  // bias_grad = y_grad.sum((n, h, w))
  if (bias_grad) {
    set_output<T>(y_grad.sum(IntArray({0, 2, 3})), bias_grad);
  }
}
template <typename T>
void gelu_grad(const Tensor& x,
               const Tensor& out_grad,
......
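As an informal sanity check of the x_grad formula above (again a stand-alone NumPy sketch, not part of this patch), the analytic gradient can be compared against a central finite difference on a small tensor:

```python
# Hypothetical numerical check of the x_grad formula (NumPy only).
import numpy as np

def inorm(x, scale, bias, eps=1e-5):
    # Instance norm forward, normalizing each (n, c) slice over (h, w).
    mean = x.mean(axis=(2, 3), keepdims=True)
    std_inv = 1.0 / np.sqrt(x.var(axis=(2, 3), keepdims=True) + eps)
    return scale.reshape(1, -1, 1, 1) * (x - mean) * std_inv + bias.reshape(1, -1, 1, 1)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))
scale = rng.standard_normal(3)
bias = rng.standard_normal(3)
y_grad = rng.standard_normal(x.shape)

# Analytic gradient via the composite-rule formula.
eps = 1e-5
mean = x.mean(axis=(2, 3), keepdims=True)
std_inv = 1.0 / np.sqrt(x.var(axis=(2, 3), keepdims=True) + eps)
x_hat = (x - mean) * std_inv
x_grad = (scale.reshape(1, -1, 1, 1) * std_inv) * (
    y_grad
    - y_grad.mean(axis=(2, 3), keepdims=True)
    - x_hat * (y_grad * x_hat).mean(axis=(2, 3), keepdims=True)
)

# Central finite difference of sum(y * y_grad) w.r.t. a single element of x.
idx = (0, 1, 2, 3)
step = 1e-6
xp, xm = x.copy(), x.copy()
xp[idx] += step
xm[idx] -= step
num = ((inorm(xp, scale, bias, eps) * y_grad).sum()
       - (inorm(xm, scale, bias, eps) * y_grad).sum()) / (2 * step)
print(np.allclose(x_grad[idx], num, atol=1e-4))  # expected: True
```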
......@@ -496,6 +496,7 @@
  data_type : x
  optional : scale
  backward : instance_norm_double_grad
  composite: instance_norm_grad(x, scale, saved_mean, saved_variance, y_grad, epsilon, x_grad, scale_grad, bias_grad)
- backward_op : layer_norm_grad
  forward : layer_norm (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) -> Tensor(out), Tensor(mean), Tensor(variance)
......