diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc
index 6dd18277c9c074ab907ceabc98d953030a853bc7..edc26dfb96fbb40c4fa4464949ec020f6ed036ea 100644
--- a/paddle/operators/layer_norm_op.cc
+++ b/paddle/operators/layer_norm_op.cc
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/operators/layer_norm_op.h"
-#include "paddle/operators/elementwise_op_function.h"
-#include "paddle/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {
@@ -23,13 +21,6 @@ using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 using DataLayout = framework::DataLayout;
 
-template <typename T>
-using EigenMatrixMapRowMajor = Eigen::Map<
-    Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
-template <typename T>
-using ConstEigenMatrixMapRowMajor = Eigen::Map<
-    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
-
 class LayerNormOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -118,75 +109,6 @@ https://arxiv.org/abs/1607.06450
   }
 };
 
-template <typename T>
-class LayerNormKernel<platform::CPUDeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const float epsilon = ctx.Attr<float>("epsilon");
-    const auto *scale = ctx.Input<Tensor>("Scale");
-    const auto *bias = ctx.Input<Tensor>("Bias");
-    const auto *x = ctx.Input<Tensor>("X");
-    const auto &x_dims = x->dims();
-    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-
-    auto *output = ctx.Output<Tensor>("Y");
-    auto *mean = ctx.Output<Tensor>("Mean");
-    auto *var = ctx.Output<Tensor>("Variance");
-    output->mutable_data<T>(ctx.GetPlace());
-    mean->mutable_data<T>(ctx.GetPlace());
-    var->mutable_data<T>(ctx.GetPlace());
-
-    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int left = static_cast<int>(matrix_dim[0]);
-    int right = static_cast<int>(matrix_dim[1]);
-
-    auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
-
-    auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
-    auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
-    auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
-
-    auto squre = [](T ele) { return ele * ele; };
-    auto add_epslion = [epsilon](T ele) { return ele + epsilon; };
-
-    mean_map = input_map.rowwise().mean();
-    var_map = (input_map - mean_map.replicate(1, right))
-                  .unaryExpr(squre)
-                  .rowwise()
-                  .mean()
-                  .unaryExpr(add_epslion);
-
-    auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
-    // TODO(zcd): Some thinking about output_map, is it appropriate that
-    // `output_map` and `input_map` point to the same memory.
-    auto inv_std = var_map.unaryExpr(inv_std_func);
-    if (scale && bias) {
-      auto scale_map =
-          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
-      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
-      output_map = (input_map - mean_map.replicate(1, right))
-                       .cwiseProduct(inv_std.replicate(1, right))
-                       .cwiseProduct(scale_map.replicate(left, 1)) +
-                   bias_map.replicate(left, 1);
-    } else if (scale) {
-      auto scale_map =
-          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
-      output_map = (input_map - mean_map.replicate(1, right))
-                       .cwiseProduct(inv_std.replicate(1, right))
-                       .cwiseProduct(scale_map.replicate(left, 1));
-    } else if (bias) {
-      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
-      output_map = (input_map - mean_map.replicate(1, right))
-                       .cwiseProduct(inv_std.replicate(1, right)) +
-                   bias_map.replicate(left, 1);
-    } else {
-      output_map = (input_map - mean_map.replicate(1, right))
-                       .cwiseProduct(inv_std.replicate(1, right));
-    }
-  }
-};
-
 class LayerNormGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -239,115 +161,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename T>
-class LayerNormGradKernel<platform::CPUDeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto *x = ctx.Input<Tensor>("X");
-    const auto *mean = ctx.Input<Tensor>("Mean");
-    const auto *var = ctx.Input<Tensor>("Variance");
-    const auto *scale = ctx.Input<Tensor>("Scale");
-    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
-
-    const auto &x_dims = x->dims();
-
-    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int left = static_cast<int>(matrix_dim[0]);
-    int right = static_cast<int>(matrix_dim[1]);
-
-    // init output
-    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
-    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-
-    auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
-    auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
-    auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
-    auto var_map = ConstEigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
-
-    if (d_bias) {
-      d_bias->mutable_data<T>(ctx.GetPlace());
-      auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
-      d_bias_map = d_y_map.colwise().sum();
-    }
-    if (d_scale) {
-      d_scale->mutable_data<T>(ctx.GetPlace());
-      auto d_scale_map =
-          EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
-      auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
-      // There are two equations to compute d_scale. One uses "Y" and the
-      // other does not use "Y".
-      d_scale_map =
-          ((x_map - mean_map.replicate(1, right))
-               .cwiseProduct(
-                   var_map.unaryExpr(inv_std_func).replicate(1, right))
-               .cwiseProduct(d_y_map))
-              .colwise()
-              .sum();
-    }
-
-    if (d_x) {
-      d_x->mutable_data<T>(ctx.GetPlace());
-      auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
-      auto triple_product_func = [](T ele) { return ele * ele * ele; };
-      auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
-
-      auto inv_std_map = var_map.unaryExpr(inv_std_func).eval();
-      // TODO(zcd): this code can be refined
-      if (d_scale) {
-        auto scale_map =
-            ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
-        // dy_dx
-        auto dx_end =
-            inv_std_map.replicate(1, right).cwiseProduct(d_y_map).cwiseProduct(
-                scale_map.replicate(left, 1));
-
-        // dy_dmean_dx
-        auto dx_mean =
-            (T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
-
-        // dy_var_dx
-        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
-                                 .cwiseProduct(scale_map.replicate(left, 1))
-                                 .cwiseProduct(d_y_map)
-                                 .rowwise()
-                                 .sum();
-        auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
-                            .cwiseProduct(dvar_end_part)
-                            .replicate(1, right);
-        auto dx_var =
-            (T(-1.0) / right) *
-            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
-
-        d_x_map = dx_end + dx_mean + dx_var;
-      } else {
-        // dy_dx
-        auto dx_end = inv_std_map.replicate(1, right).cwiseProduct(d_y_map);
-
-        // dy_dmean_dx
-        auto dx_mean =
-            (T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
-
-        // dy_var_dx
-        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
-                                 .cwiseProduct(d_y_map)
-                                 .rowwise()
-                                 .sum();
-        auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
-                            .cwiseProduct(dvar_end_part)
-                            .replicate(1, right);
-        auto dx_var =
-            (T(-1.0) / right) *
-            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
-
-        d_x_map = dx_end + dx_mean + dx_var;
-      }
-    }
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/layer_norm_op.cu b/paddle/operators/layer_norm_op.cu
index a84f5a41eae4343d622ffcba1306b90739b6f0bd..77d13b216f0e8d6d4434742908437f1eb74818c9 100644
--- a/paddle/operators/layer_norm_op.cu
+++ b/paddle/operators/layer_norm_op.cu
@@ -12,234 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/elementwise_op_function.h"
 #include "paddle/operators/layer_norm_op.h"
-#include "paddle/operators/math/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
-using DataLayout = framework::DataLayout;
-
-namespace {
-template <typename T>
-struct SubAndSquareFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
-};
-
-template <typename T>
-struct DivAndSqrtFunctor {
-  explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
-  inline HOSTDEVICE T operator()(T a, T b) const {
-    return a / (sqrt(b) + epsilon_);
-  }
-
- private:
-  T epsilon_;
-};
-
-template <typename T>
-struct MulFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
-};
-
-template <typename T>
-struct AddFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
-};
-
-template <typename T>
-struct SubFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
-};
-
-template <typename T>
-struct MulInvVarFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const {
-    return a * std::sqrt(1.0 / b);
-  }
-};
-}  // namespace
-
-template <typename DeviceContext, typename T>
-class LayerNormCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const float epsilon = ctx.Attr<float>("epsilon");
-    auto *scale = ctx.Input<Tensor>("Scale");
-    auto *bias = ctx.Input<Tensor>("Bias");
-    auto x = *ctx.Input<Tensor>("X");
-
-    auto *y = ctx.Output<Tensor>("Y");
-    auto *mean = ctx.Output<Tensor>("Mean");
-    auto *var = ctx.Output<Tensor>("Variance");
-    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-
-    const auto &x_dims = x.dims();
-
-    y->mutable_data<T>(ctx.GetPlace());
-    mean->mutable_data<T>(ctx.GetPlace());
-    var->mutable_data<T>(ctx.GetPlace());
-
-    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int left = static_cast<int>(matrix_dim[0]);
-    int right = static_cast<int>(matrix_dim[1]);
-
-    framework::DDim matrix_shape({left, right});
-
-    x.Resize(matrix_shape);
-    y->Resize(matrix_shape);
-
-    auto &dev_ctx = ctx.template device_context<DeviceContext>();
-    math::RowwiseMean<DeviceContext, T> row_mean;
-
-    // functor-> get mean
-    row_mean(dev_ctx, x, mean);
-
-    // functor-> get variance
-    ElementwiseComputeEx<SubAndSquareFunctor<T>, DeviceContext, T>(
-        ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), y);
-    row_mean(dev_ctx, *y, var);
-
-    // functor-> get norm_out
-    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-        ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), y);
-    ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
-        ctx, y, var, /*axis*/ 0, DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
-        y);
-
-    framework::DDim scale_shape({right});
-    if (scale) {
-      Tensor scale_matrix = *scale;
-      scale_matrix.Resize(scale_shape);
-      ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
-          ctx, y, &scale_matrix, /*axis*/ 1, MulFunctor<T>(), y);
-    }
-    if (bias) {
-      Tensor bias_matrix = *bias;
-      bias_matrix.Resize(scale_shape);
-      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
-          ctx, y, &bias_matrix, /*axis*/ 1, AddFunctor<T>(), y);
-    }
-    y->Resize(x_dims);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class LayerNormCUDAGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const float epsilon = ctx.Attr<float>("epsilon");
-    auto x = *ctx.Input<Tensor>("X");
-    auto mean = *ctx.Input<Tensor>("Mean");
-    auto var = *ctx.Input<Tensor>("Variance");
-    auto scale = *ctx.Input<Tensor>("Scale");
-    auto d_y = *ctx.Input<Tensor>(framework::GradVarName("Y"));
-    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-
-    // init output
-    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *d_scale =
-        ctx.Output<Tensor>(framework::GradVarName("Scale"));
-    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-
-    const auto &x_dims = x.dims();
-    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int left = static_cast<int>(matrix_dim[0]);
-    int right = static_cast<int>(matrix_dim[1]);
-    framework::DDim matrix_shape({left, right});
-
-    d_y.Resize(matrix_shape);
-    auto &dev_ctx = ctx.template device_context<DeviceContext>();
-    math::ColwiseSum<DeviceContext, T> colwise_sum;
-
-    Tensor temp;
-    Tensor temp_norm;
-    if (d_scale || d_x) {
-      x.Resize(matrix_shape);
-      temp.mutable_data<T>(matrix_shape, ctx.GetPlace());
-      temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
-
-      // get x_norm
-      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          ctx, &x, &mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
-      ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
-          ctx, &temp_norm, &var, /*axis*/ 0,
-          DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
-    }
-
-    if (d_bias) {
-      d_bias->mutable_data<T>(ctx.GetPlace());
-      colwise_sum(dev_ctx, d_y, d_bias);
-    }
-    if (d_scale) {
-      d_scale->mutable_data<T>(ctx.GetPlace());
-      ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
-          ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor<T>(), &temp);
-      colwise_sum(dev_ctx, temp, d_scale);
-    }
-
-    if (d_x) {
-      framework::DDim vec_shape({left});
-      d_x->mutable_data<T>(ctx.GetPlace());
-      Tensor temp_vec;
-      temp_vec.mutable_data<T>(vec_shape, ctx.GetPlace());
-
-      auto &dev_ctx = ctx.template device_context<DeviceContext>();
-      math::RowwiseMean<DeviceContext, T> row_mean;
-
-      if (d_scale) {
-        // dy_dx
-        ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
-            ctx, &d_y, &scale, /*axis*/ 1, MulFunctor<T>(), &temp);
-        framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x);
-
-        // dy_dmean_dx
-        row_mean(dev_ctx, temp, &temp_vec);
-        ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-            ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
-
-        // dy_var_dx
-        ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
-            ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
-
-      } else {
-        // dy_dx
-        framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x);
-
-        // dy_dmean_dx
-        row_mean(dev_ctx, d_y, &temp_vec);
-        ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-            ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
-
-        // dy_var_dx
-        ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
-            ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
-      }
-      // dy_var_dx
-      row_mean(dev_ctx, temp, &temp_vec);
-      ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
-          ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor<T>(), &temp_norm);
-      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          ctx, d_x, &temp_norm, /*axis*/ 0, SubFunctor<T>(), d_x);
-
-      ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
-          ctx, d_x, &var, /*axis*/ 0,
-          DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), d_x);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     layer_norm,
-    ops::LayerNormCUDAKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::LayerNormCUDAKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     layer_norm_grad,
-    ops::LayerNormCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::LayerNormCUDAGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>);
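Note: after this change the CUDA build simply registers the device-generic kernels now defined in layer_norm_op.h. One subtlety worth calling out: the shared DivAndSqrtFunctor computes a / (sqrt(b) + epsilon), so epsilon is added to the standard deviation, whereas the removed Eigen CPU kernel added epsilon to the variance before taking the square root. The two conventions only agree when the variance dominates epsilon; a tiny self-contained check (illustrative values only):

// Illustrative comparison of the two epsilon conventions in this diff.
#include <cmath>
#include <cstdio>

int main() {
  const float var = 1e-8f, eps = 1e-5f;
  // Removed CPU kernel: epsilon folded into the variance, then sqrt.
  const float inv_std_old = std::sqrt(1.f / (var + eps));
  // Shared DivAndSqrtFunctor: epsilon added to the standard deviation.
  const float inv_std_new = 1.f / (std::sqrt(var) + eps);
  // Prints roughly 316 vs 9091: the conventions diverge when var is tiny.
  std::printf("old=%g new=%g\n", inv_std_old, inv_std_new);
  return 0;
}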
diff --git a/paddle/operators/layer_norm_op.h b/paddle/operators/layer_norm_op.h
index bca35b91e6f52d35dee14aac9d080b52914942e3..309f1b87a26c3f5cd9782f9a884b79e02972a820 100644
--- a/paddle/operators/layer_norm_op.h
+++ b/paddle/operators/layer_norm_op.h
@@ -16,19 +16,219 @@ limitations under the License. */
 
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/elementwise_op_function.h"
+#include "paddle/operators/math/math_function.h"
+
 namespace paddle {
 namespace operators {
 
+template <typename T>
+struct SubAndSquareFunctor {
+  inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
+};
+
+template <typename T>
+struct DivAndSqrtFunctor {
+  explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
+  inline HOSTDEVICE T operator()(T a, T b) const {
+    return a / (sqrt(b) + epsilon_);
+  }
+
+ private:
+  T epsilon_;
+};
+
+template <typename T>
+struct MulFunctor {
+  inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
+};
+
+template <typename T>
+struct AddFunctor {
+  inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
+};
+
+template <typename T>
+struct SubFunctor {
+  inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
+};
+
+template <typename T>
+struct MulInvVarFunctor {
+  inline HOSTDEVICE T operator()(T a, T b) const {
+    return a * std::sqrt(1.0 / b);
+  }
+};
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using DataLayout = framework::DataLayout;
+
 template <typename DeviceContext, typename T>
 class LayerNormKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override;
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const float epsilon = ctx.Attr<float>("epsilon");
+    auto *scale = ctx.Input<Tensor>("Scale");
+    auto *bias = ctx.Input<Tensor>("Bias");
+    auto x = *ctx.Input<Tensor>("X");
+
+    auto *y = ctx.Output<Tensor>("Y");
+    auto *mean = ctx.Output<Tensor>("Mean");
+    auto *var = ctx.Output<Tensor>("Variance");
+    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
+
+    const auto &x_dims = x.dims();
+
+    y->mutable_data<T>(ctx.GetPlace());
+    mean->mutable_data<T>(ctx.GetPlace());
+    var->mutable_data<T>(ctx.GetPlace());
+
+    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
+    int left = static_cast<int>(matrix_dim[0]);
+    int right = static_cast<int>(matrix_dim[1]);
+
+    framework::DDim matrix_shape({left, right});
+
+    x.Resize(matrix_shape);
+    y->Resize(matrix_shape);
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    math::RowwiseMean<DeviceContext, T> row_mean;
+
+    // functor-> get mean
+    row_mean(dev_ctx, x, mean);
+
+    // functor-> get variance
+    ElementwiseComputeEx<SubAndSquareFunctor<T>, DeviceContext, T>(
+        ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), y);
+    row_mean(dev_ctx, *y, var);
+
+    // functor-> get norm_out
+    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
+        ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), y);
+    ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
+        ctx, y, var, /*axis*/ 0, DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
+        y);
+
+    framework::DDim scale_shape({right});
+    if (scale) {
+      Tensor scale_matrix = *scale;
+      scale_matrix.Resize(scale_shape);
+      ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+          ctx, y, &scale_matrix, /*axis*/ 1, MulFunctor<T>(), y);
+    }
+    if (bias) {
+      Tensor bias_matrix = *bias;
+      bias_matrix.Resize(scale_shape);
+      ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
+          ctx, y, &bias_matrix, /*axis*/ 1, AddFunctor<T>(), y);
+    }
+    y->Resize(x_dims);
+  }
 };
 
 template <typename DeviceContext, typename T>
 class LayerNormGradKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override;
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const float epsilon = ctx.Attr<float>("epsilon");
+    auto x = *ctx.Input<Tensor>("X");
+    auto mean = *ctx.Input<Tensor>("Mean");
+    auto var = *ctx.Input<Tensor>("Variance");
+    auto scale = *ctx.Input<Tensor>("Scale");
+    auto d_y = *ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
+
+    // init output
+    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    const auto &x_dims = x.dims();
+    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
+    int left = static_cast<int>(matrix_dim[0]);
+    int right = static_cast<int>(matrix_dim[1]);
+    framework::DDim matrix_shape({left, right});
+
+    d_y.Resize(matrix_shape);
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    math::ColwiseSum<DeviceContext, T> colwise_sum;
+
+    Tensor temp;
+    Tensor temp_norm;
+    if (d_scale || d_x) {
+      x.Resize(matrix_shape);
+      temp.mutable_data<T>(matrix_shape, ctx.GetPlace());
+      temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
+
+      // get x_norm
+      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
+          ctx, &x, &mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
+      ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
+          ctx, &temp_norm, &var, /*axis*/ 0,
+          DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
+    }
+
+    if (d_bias) {
+      d_bias->mutable_data<T>(ctx.GetPlace());
+      colwise_sum(dev_ctx, d_y, d_bias);
+    }
+    if (d_scale) {
+      d_scale->mutable_data<T>(ctx.GetPlace());
+      ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+          ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor<T>(), &temp);
+      colwise_sum(dev_ctx, temp, d_scale);
+    }
+
+    if (d_x) {
+      framework::DDim vec_shape({left});
+      d_x->mutable_data<T>(ctx.GetPlace());
+      Tensor temp_vec;
+      temp_vec.mutable_data<T>(vec_shape, ctx.GetPlace());
+
+      math::RowwiseMean<DeviceContext, T> row_mean;
+
+      if (d_scale) {
+        // dy_dx
+        ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+            ctx, &d_y, &scale, /*axis*/ 1, MulFunctor<T>(), &temp);
+        framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x);
+
+        // dy_dmean_dx
+        row_mean(dev_ctx, temp, &temp_vec);
+        ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
+            ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
+
+        // dy_var_dx
+        ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+            ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
+
+      } else {
+        // dy_dx
+        framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x);
+
+        // dy_dmean_dx
+        row_mean(dev_ctx, d_y, &temp_vec);
+        ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
+            ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
+
+        // dy_var_dx
+        ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+            ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
+      }
+      // dy_var_dx
+      row_mean(dev_ctx, temp, &temp_vec);
+      ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+          ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor<T>(), &temp_norm);
+      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
+          ctx, d_x, &temp_norm, /*axis*/ 0, SubFunctor<T>(), d_x);
+
+      ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
+          ctx, d_x, &var, /*axis*/ 0,
+          DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), d_x);
+    }
+  }
 };
 
 }  // namespace operators
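Note: the d_x branch of the grad kernels sums three contributions, labelled dy_dx, dy_dmean_dx, and dy_var_dx in the code above. A compact single-threaded sketch of the same decomposition, for reference only (the helper name LayerNormGradXRef is hypothetical; it follows the removed Eigen kernel's convention that the stored variance already includes epsilon, assumes a scale tensor is present, and assumes d_x is pre-sized):

// Reference sketch of d_x = dx_end + dx_mean + dx_var; not operator code.
#include <cmath>
#include <vector>

void LayerNormGradXRef(const std::vector<float> &x,
                       const std::vector<float> &d_y,
                       const std::vector<float> &scale,
                       const std::vector<float> &mean,
                       const std::vector<float> &var,  // includes epsilon
                       int left, int right, std::vector<float> *d_x) {
  for (int i = 0; i < left; ++i) {
    const float inv_std = 1.f / std::sqrt(var[i]);
    // Row-wise reductions shared by the mean and variance terms.
    float sum_end = 0.f, sum_var = 0.f;
    for (int j = 0; j < right; ++j) {
      const float g = d_y[i * right + j] * scale[j];
      sum_end += inv_std * g;                       // feeds dy_dmean_dx
      sum_var += (x[i * right + j] - mean[i]) * g;  // feeds dy_var_dx
    }
    for (int j = 0; j < right; ++j) {
      const float g = d_y[i * right + j] * scale[j];
      const float dx_end = inv_std * g;             // dy_dx
      const float dx_mean = -sum_end / right;       // dy_dmean_dx
      const float dx_var = -(x[i * right + j] - mean[i]) * inv_std * inv_std *
                           inv_std * sum_var / right;  // dy_var_dx
      (*d_x)[i * right + j] = dx_end + dx_mean + dx_var;
    }
  }
}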