Commit e0333735 authored by chengduoZH

unified GPU and CPU implementation

Parent 76e188e5
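Editor's note (not part of the original commit): with the input flattened to a [left, right] matrix at begin_norm_axis, the forward computation implemented by both the removed CPU kernel and the new unified kernel is, per row i,

$$
\mu_i = \frac{1}{right}\sum_{j} x_{ij},\qquad
\sigma_i^2 = \frac{1}{right}\sum_{j}\left(x_{ij}-\mu_i\right)^2,\qquad
y_{ij} = \frac{x_{ij}-\mu_i}{\sqrt{\sigma_i^2}+\epsilon}\,\gamma_j + \beta_j,
$$

where the scale \gamma and bias \beta are optional. The removed CPU kernel adds \epsilon to \sigma_i^2 inside the square root instead; see the note after that kernel below.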
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/layer_norm_op.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
@@ -23,13 +21,6 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename T>
using EigenMatrixMapRowMajor = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
class LayerNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -118,75 +109,6 @@ https://arxiv.org/abs/1607.06450
}
};
template <typename T>
class LayerNormKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto *output = ctx.Output<Tensor>("Y");
auto *mean = ctx.Output<Tensor>("Mean");
auto *var = ctx.Output<Tensor>("Variance");
output->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
auto square = [](T ele) { return ele * ele; };
auto add_epsilon = [epsilon](T ele) { return ele + epsilon; };
mean_map = input_map.rowwise().mean();
var_map = (input_map - mean_map.replicate(1, right))
.unaryExpr(square)
.rowwise()
.mean()
.unaryExpr(add_epsilon);
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// TODO(zcd): Reconsider whether it is appropriate for `output_map` and
// `input_map` to point to the same memory.
auto inv_std = var_map.unaryExpr(inv_std_func);
if (scale && bias) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1)) +
bias_map.replicate(left, 1);
} else if (scale) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1));
} else if (bias) {
auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right)) +
bias_map.replicate(left, 1);
} else {
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right));
}
}
};
class LayerNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -239,115 +161,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
}
};
template <typename T>
class LayerNormGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
const auto *mean = ctx.Input<Tensor>("Mean");
const auto *var = ctx.Input<Tensor>("Variance");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
auto var_map = ConstEigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
d_bias_map = d_y_map.colwise().sum();
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
auto d_scale_map =
EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// There are two equations for computing d_scale: one uses "Y" and the
// other does not.
d_scale_map =
((x_map - mean_map.replicate(1, right))
.cwiseProduct(
var_map.unaryExpr(inv_std_func).replicate(1, right))
.cwiseProduct(d_y_map))
.colwise()
.sum();
}
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
auto triple_product_func = [](T ele) { return ele * ele * ele; };
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
auto inv_std_map = var_map.unaryExpr(inv_std_func).eval();
// TODO(zcd): this code can be refined
if (d_scale) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
// dy_dx
auto dx_end =
inv_std_map.replicate(1, right).cwiseProduct(d_y_map).cwiseProduct(
scale_map.replicate(left, 1));
// dy_dmean_dx
auto dx_mean =
(T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
(T(-1.0) / right) *
(x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
d_x_map = dx_end + dx_mean + dx_var;
} else {
// dy_dx
auto dx_end = inv_std_map.replicate(1, right).cwiseProduct(d_y_map);
// dy_dmean_dx
auto dx_mean =
(T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
(T(-1.0) / right) *
(x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
d_x_map = dx_end + dx_mean + dx_var;
}
}
}
};
} // namespace operators
} // namespace paddle
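Editor's note: the removed CPU Eigen kernel above adds epsilon to the variance before taking the inverse square root, whereas the unified implementation below (DivAndSqrtFunctor) adds epsilon to the standard deviation, i.e. computes a / (sqrt(b) + epsilon). The results are numerically very close for typical epsilon values but not bit-identical. A minimal standalone C++ sketch with hypothetical values, just to contrast the two formulas:

#include <cmath>
#include <cstdio>

int main() {
  const double eps = 1e-5;      // hypothetical epsilon attribute value
  const double centered = 2.0;  // x - mean (hypothetical)
  const double variance = 4.0;  // mean((x - mean)^2) (hypothetical)

  // Removed CPU kernel: epsilon is added to the variance before the sqrt.
  const double old_style = centered / std::sqrt(variance + eps);
  // Unified kernel (DivAndSqrtFunctor): epsilon is added to the standard deviation.
  const double new_style = centered / (std::sqrt(variance) + eps);

  std::printf("old: %.9f  new: %.9f\n", old_style, new_style);
  return 0;
}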
......
@@ -12,234 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/operators/layer_norm_op.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
namespace {
template <typename T>
struct SubAndSquareFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
};
template <typename T>
struct DivAndSqrtFunctor {
explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
inline HOSTDEVICE T operator()(T a, T b) const {
return a / (sqrt(b) + epsilon_);
}
private:
T epsilon_;
};
template <typename T>
struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename T>
struct MulInvVarFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
return a * std::sqrt(1.0 / b);
}
};
} // namespace
template <typename DeviceContext, typename T>
class LayerNormCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
auto x = *ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
auto *mean = ctx.Output<Tensor>("Mean");
auto *var = ctx.Output<Tensor>("Variance");
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto &x_dims = x.dims();
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
x.Resize(matrix_shape);
y->Resize(matrix_shape);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::RowwiseMean<DeviceContext, T> row_mean;
// functor-> get mean
row_mean(dev_ctx, x, mean);
// functor-> get variance
ElementwiseComputeEx<SubAndSquareFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), y);
row_mean(dev_ctx, *y, var);
// functor-> get norm_out
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), y);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, y, var, /*axis*/ 0, DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
y);
framework::DDim scale_shape({right});
if (scale) {
Tensor scale_matrix = *scale;
scale_matrix.Resize(scale_shape);
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, y, &scale_matrix, /*axis*/ 1, MulFunctor<T>(), y);
}
if (bias) {
Tensor bias_matrix = *bias;
bias_matrix.Resize(scale_shape);
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, y, &bias_matrix, /*axis*/ 1, AddFunctor<T>(), y);
}
y->Resize(x_dims);
}
};
template <typename DeviceContext, typename T>
class LayerNormCUDAGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto x = *ctx.Input<Tensor>("X");
auto mean = *ctx.Input<Tensor>("Mean");
auto var = *ctx.Input<Tensor>("Variance");
auto scale = *ctx.Input<Tensor>("Scale");
auto d_y = *ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto &x_dims = x.dims();
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
d_y.Resize(matrix_shape);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::ColwiseSum<DeviceContext, T> colwise_sum;
Tensor temp;
Tensor temp_norm;
if (d_scale || d_x) {
x.Resize(matrix_shape);
temp.mutable_data<T>(matrix_shape, ctx.GetPlace());
temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
// get x_norm
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, &mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
}
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
colwise_sum(dev_ctx, d_y, d_bias);
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor<T>(), &temp);
colwise_sum(dev_ctx, temp, d_scale);
}
if (d_x) {
framework::DDim vec_shape({left});
d_x->mutable_data<T>(ctx.GetPlace());
Tensor temp_vec;
temp_vec.mutable_data<T>(vec_shape, ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::RowwiseMean<DeviceContext, T> row_mean;
if (d_scale) {
// dy_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, &scale, /*axis*/ 1, MulFunctor<T>(), &temp);
framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
} else {
// dy_dx
framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, d_y, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
}
// dy_var_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor<T>(), &temp_norm);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_norm, /*axis*/ 0, SubFunctor<T>(), d_x);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, d_x, &var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), d_x);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
layer_norm,
ops::LayerNormCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormCUDAKernel<paddle::platform::CUDADeviceContext, double>);
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
layer_norm_grad,
ops::LayerNormCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormCUDAGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>);
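Editor's note: with the kernels now defined once in layer_norm_op.h, this file only registers the shared LayerNormKernel / LayerNormGradKernel for CUDA. The matching CPU registration lives in layer_norm_op.cc and is not shown in this excerpt; presumably it instantiates the same templates with CPUDeviceContext, roughly as sketched below (hypothetical, not part of this diff):

REGISTER_OP_CPU_KERNEL(
    layer_norm,
    ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, double>);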
@@ -16,19 +16,219 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename T>
struct SubAndSquareFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
};
template <typename T>
struct DivAndSqrtFunctor {
explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
inline HOSTDEVICE T operator()(T a, T b) const {
return a / (sqrt(b) + epsilon_);
}
private:
T epsilon_;
};
template <typename T>
struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename T>
struct MulInvVarFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
return a * std::sqrt(1.0 / b);
}
};
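// Editor's note (added comment, not in the original source): MulInvVarFunctor
// does not appear to be used anywhere in this excerpt; it may be kept for
// other callers or left over from an earlier revision.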
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename DeviceContext, typename T>
class LayerNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
auto x = *ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
auto *mean = ctx.Output<Tensor>("Mean");
auto *var = ctx.Output<Tensor>("Variance");
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto &x_dims = x.dims();
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
x.Resize(matrix_shape);
y->Resize(matrix_shape);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::RowwiseMean<DeviceContext, T> row_mean;
// functor-> get mean
row_mean(dev_ctx, x, mean);
// functor-> get variance
ElementwiseComputeEx<SubAndSquareFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), y);
row_mean(dev_ctx, *y, var);
// functor-> get norm_out
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), y);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, y, var, /*axis*/ 0, DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
y);
framework::DDim scale_shape({right});
if (scale) {
Tensor scale_matrix = *scale;
scale_matrix.Resize(scale_shape);
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, y, &scale_matrix, /*axis*/ 1, MulFunctor<T>(), y);
}
if (bias) {
Tensor bias_matrix = *bias;
bias_matrix.Resize(scale_shape);
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, y, &bias_matrix, /*axis*/ 1, AddFunctor<T>(), y);
}
y->Resize(x_dims);
}
};
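// Editor's note (added comment, not in the original source): in the calls
// above, `mean` and `var` hold one value per row, so axis = 0 aligns them with
// the `left` (row) dimension of the [left, right] matrix and broadcasts them
// across the `right` columns; `scale` and `bias` are resized to shape [right],
// so axis = 1 aligns them with the column dimension and broadcasts them across
// rows. This is my reading of ElementwiseComputeEx's broadcasting convention.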
template <typename DeviceContext, typename T>
class LayerNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto x = *ctx.Input<Tensor>("X");
auto mean = *ctx.Input<Tensor>("Mean");
auto var = *ctx.Input<Tensor>("Variance");
auto scale = *ctx.Input<Tensor>("Scale");
auto d_y = *ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto &x_dims = x.dims();
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
d_y.Resize(matrix_shape);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::ColwiseSum<DeviceContext, T> colwise_sum;
Tensor temp;
Tensor temp_norm;
if (d_scale || d_x) {
x.Resize(matrix_shape);
temp.mutable_data<T>(matrix_shape, ctx.GetPlace());
temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
// get x_norm
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, &mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
}
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
colwise_sum(dev_ctx, d_y, d_bias);
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor<T>(), &temp);
colwise_sum(dev_ctx, temp, d_scale);
}
if (d_x) {
framework::DDim vec_shape({left});
d_x->mutable_data<T>(ctx.GetPlace());
Tensor temp_vec;
temp_vec.mutable_data<T>(vec_shape, ctx.GetPlace());
math::RowwiseMean<DeviceContext, T> row_mean;
if (d_scale) {
// dy_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, &scale, /*axis*/ 1, MulFunctor<T>(), &temp);
framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
} else {
// dy_dx
framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, d_y, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
}
// dy_var_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor<T>(), &temp_norm);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_norm, /*axis*/ 0, SubFunctor<T>(), d_x);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, d_x, &var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), d_x);
}
}
};
} // namespace operators
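Editor's note (not part of the original commit): written out, the gradients that this kernel (and the removed CPU kernel) compute are, with N = right, \hat{x}_{ij} = (x_{ij}-\mu_i)/\sigma_i, and g_{ij} = \gamma_j\,\partial L/\partial y_{ij},

$$
\frac{\partial L}{\partial \beta_j} = \sum_i \frac{\partial L}{\partial y_{ij}},\qquad
\frac{\partial L}{\partial \gamma_j} = \sum_i \hat{x}_{ij}\,\frac{\partial L}{\partial y_{ij}},\qquad
\frac{\partial L}{\partial x_{ij}} = \frac{1}{\sigma_i}\Big(g_{ij} - \frac{1}{N}\sum_k g_{ik} - \hat{x}_{ij}\,\frac{1}{N}\sum_k g_{ik}\,\hat{x}_{ik}\Big),
$$

which matches the dx_end + dx_mean + dx_var decomposition in the code (with g_{ij} = \partial L/\partial y_{ij} when no scale is given).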