Commit 5092f529 authored by chengduoZH

Separate GPU and CPU implementation

Parent e0333735
@@ -21,6 +21,13 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;

template <typename T>
using EigenMatrixMapRowMajor = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
class LayerNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -101,7 +108,6 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
    AddComment(R"DOC(
Layer Normalization.
Layer Norm has been implemented as discussed in the paper:
https://arxiv.org/abs/1607.06450
...
@@ -109,6 +115,75 @@ https://arxiv.org/abs/1607.06450
  }
};
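// CPU-specialized forward kernel: the input is flattened to a [left, right]
// matrix (one row per instance to normalize) and, for every row,
//   mean     = rowwise mean of x
//   variance = rowwise mean of (x - mean)^2 + epsilon
//   y        = (x - mean) / sqrt(variance) * scale + bias
// where `scale` and `bias` are optional and broadcast across all rows.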
template <typename T>
class LayerNormKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto *output = ctx.Output<Tensor>("Y");
auto *mean = ctx.Output<Tensor>("Mean");
auto *var = ctx.Output<Tensor>("Variance");
output->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
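// Row-major Eigen maps over the raw tensor buffers: `input_map`/`output_map`
// are [left x right]; `mean_map`/`var_map` hold one value per row.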
auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
auto square = [](T ele) { return ele * ele; };
auto add_epsilon = [epsilon](T ele) { return ele + epsilon; };
mean_map = input_map.rowwise().mean();
var_map = (input_map - mean_map.replicate(1, right))
    .unaryExpr(square)
    .rowwise()
    .mean()
    .unaryExpr(add_epsilon);
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// TODO(zcd): Reconsider whether it is appropriate for `output_map` and
// `input_map` to point to the same memory.
auto inv_std = var_map.unaryExpr(inv_std_func);
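// Four cases depending on whether `scale` and `bias` are provided; in every
// case output = (x - mean) * inv_std, optionally multiplied by `scale` and
// shifted by `bias` (both broadcast to each row).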
if (scale && bias) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1)) +
bias_map.replicate(left, 1);
} else if (scale) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1));
} else if (bias) {
auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right)) +
bias_map.replicate(left, 1);
} else {
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right));
}
}
};
class LayerNormGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -161,6 +236,115 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
  }
};
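// CPU-specialized backward kernel. With x_hat = (x - mean) * inv_std, it
// computes over the flattened [left, right] matrix:
//   d_bias  = column-wise sum of d_y
//   d_scale = column-wise sum of d_y * x_hat
//   d_x     = the gradient flowing directly through x_hat plus the terms
//             propagated through the row mean and the row variance.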
template <typename T>
class LayerNormGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
const auto *mean = ctx.Input<Tensor>("Mean");
const auto *var = ctx.Input<Tensor>("Variance");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
auto var_map = ConstEigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
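// The bias and scale gradients are reductions (column-wise sums) over the
// `left` rows.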
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
d_bias_map = d_y_map.colwise().sum();
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
auto d_scale_map =
EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// There are two equations for computing d_scale: one uses "Y" and the other
// does not.
d_scale_map =
((x_map - mean_map.replicate(1, right))
.cwiseProduct(
var_map.unaryExpr(inv_std_func).replicate(1, right))
.cwiseProduct(d_y_map))
.colwise()
.sum();
}
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
auto triple_product_func = [](T ele) { return ele * ele * ele; };
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
auto inv_std_map = var_map.unaryExpr(inv_std_func).eval();
// TODO(zcd): this code could be refined
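// d_x is assembled from three terms: dx_end (the gradient through the
// normalized value), dx_mean (the part propagated through the row mean),
// and dx_var (the part propagated through the row variance).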
if (d_scale) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
// dy_dx
auto dx_end =
inv_std_map.replicate(1, right).cwiseProduct(d_y_map).cwiseProduct(
scale_map.replicate(left, 1));
// dy_dmean_dx
auto dx_mean =
(T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
(T(-1.0) / right) *
(x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
d_x_map = dx_end + dx_mean + dx_var;
} else {
// dy_dx
auto dx_end = inv_std_map.replicate(1, right).cwiseProduct(d_y_map);
// dy_dmean_dx
auto dx_mean =
(T(-1.0) / right) * dx_end.rowwise().sum().replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = inv_std_map.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
(T(-1.0) / right) *
(x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
d_x_map = dx_end + dx_mean + dx_var;
}
}
}
};
} // namespace operators
} // namespace paddle
...
@@ -78,7 +78,7 @@ class LayerNormKernel : public framework::OpKernel<T> {
     auto *var = ctx.Output<Tensor>("Variance");
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-    const auto &x_dims = x.dims();
+    const auto x_dims = x.dims();
     y->mutable_data<T>(ctx.GetPlace());
     mean->mutable_data<T>(ctx.GetPlace());
@@ -87,11 +87,12 @@ class LayerNormKernel : public framework::OpKernel<T> {
     auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
     int left = static_cast<int>(matrix_dim[0]);
     int right = static_cast<int>(matrix_dim[1]);
     framework::DDim matrix_shape({left, right});
     x.Resize(matrix_shape);
-    y->Resize(matrix_shape);
+    Tensor out;
+    out.ShareDataWith(*y);
+    out.Resize(matrix_shape);
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
     math::RowwiseMean<DeviceContext, T> row_mean;
@@ -101,30 +102,24 @@ class LayerNormKernel : public framework::OpKernel<T> {
     // functor-> get variance
     ElementwiseComputeEx<SubAndSquareFunctor<T>, DeviceContext, T>(
-        ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), y);
-    row_mean(dev_ctx, *y, var);
+        ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), &out);
+    row_mean(dev_ctx, out, var);
     // functor-> get norm_out
     ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-        ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), y);
+        ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), &out);
     ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
-        ctx, y, var, /*axis*/ 0, DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
-        y);
-    framework::DDim scale_shape({right});
+        ctx, &out, var, /*axis*/ 0,
+        DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &out);
     if (scale) {
-      Tensor scale_matrix = *scale;
-      scale_matrix.Resize(scale_shape);
       ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
-          ctx, y, &scale_matrix, /*axis*/ 1, MulFunctor<T>(), y);
+          ctx, &out, scale, /*axis*/ 1, MulFunctor<T>(), &out);
     }
     if (bias) {
-      Tensor bias_matrix = *bias;
-      bias_matrix.Resize(scale_shape);
       ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
-          ctx, y, &bias_matrix, /*axis*/ 1, AddFunctor<T>(), y);
+          ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out);
     }
-    y->Resize(x_dims);
   }
 };
@@ -184,6 +179,7 @@ class LayerNormGradKernel : public framework::OpKernel<T> {
     if (d_x) {
       framework::DDim vec_shape({left});
       d_x->mutable_data<T>(ctx.GetPlace());
+      auto dx_dim = d_x->dims();
       Tensor temp_vec;
       temp_vec.mutable_data<T>(vec_shape, ctx.GetPlace());
@@ -227,6 +223,7 @@ class LayerNormGradKernel : public framework::OpKernel<T> {
       ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
           ctx, d_x, &var, /*axis*/ 0,
           DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), d_x);
+      d_x->Resize(dx_dim);
     }
   }
 };
...
@@ -62,9 +62,9 @@ def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
     grad_x = dx_end + d_mean + d_std
-    grad_y.shape = x_shape
-    x.shape = x_shape
+    grad_x.shape, x.shape, grad_y.shape = x_shape, x_shape, x_shape
     scale.shape = scale_shape
+    var.shape, mean.shape = [N, ], [N, ]
     return grad_x, d_scale, d_bias
@@ -112,10 +112,7 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
 class TestLayerNormdOp(OpTest):
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
-        self.assertTrue(
-            np.allclose(
-                np.array(tensor).reshape(np_array.shape), np_array, atol=atol),
-            msg)
+        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
     def __assert_grad_close(self,
                             tensor,
@@ -123,7 +120,7 @@ class TestLayerNormdOp(OpTest):
                             name,
                             place,
                             max_relative_error=0.02):
-        a = np.array(tensor).reshape(np_array.shape)
+        a = np.array(tensor)
         b = np_array
         abs_a = np.abs(a)
         abs_a[abs_a < 1e-5] = 1
...