Commit d63b7c60 authored by guosheng

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into add-python-layernorm

@@ -62,7 +62,7 @@ class CompareOpKernel
z->mutable_data<T>(context.GetPlace());
int axis = context.Attr<int>("axis");
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
-z);
+Functor(), z);
}
};
......
@@ -35,7 +35,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
-ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+AddFunctor<T>(), z);
}
};
......
@@ -35,7 +35,8 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
-ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+DivFunctor<T>(), z);
}
};
......
@@ -35,7 +35,8 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
-ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+MaxFunctor<T>(), z);
}
};
......
@@ -35,7 +35,8 @@ class ElementwiseMinKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
-ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+MinFunctor<T>(), z);
}
};
......
@@ -34,7 +34,8 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
-ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+MulFunctor<T>(), z);
}
};
......
@@ -365,10 +365,10 @@ template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
-const framework::Tensor* y, int axis,
+const framework::Tensor* y, int axis, Functor func,
framework::Tensor* z) {
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
-x, y, z, ctx.template device_context<DeviceContext>(), Functor());
+x, y, z, ctx.template device_context<DeviceContext>(), func);
auto x_dims = x->dims();
auto y_dims = y->dims();
......
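Note on the signature change above: ElementwiseComputeEx now receives the functor instance from the caller instead of default-constructing Functor() internally, so functors can carry state. As a rough illustration (the wrapper function and variable names below are made up for this note; only the ElementwiseComputeEx call shape comes from the commit), a caller passing a parameterized functor such as the DivAndSqrtFunctor introduced in layer_norm_op.h might look like:

// Illustrative sketch only, not part of the commit.
// "centered" holds x - mean with shape [N, D]; "variance" holds one value per row.
// axis = 0 broadcasts the per-row variance across each row of "centered".
template <typename DeviceContext, typename T>
void DivideByStd(const framework::ExecutionContext& ctx,
                 const framework::Tensor* centered,
                 const framework::Tensor* variance, T epsilon,
                 framework::Tensor* out) {
  ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
      ctx, centered, variance, /*axis*/ 0,
      DivAndSqrtFunctor<T>(epsilon), out);  // functor constructed with state
}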
...@@ -36,7 +36,8 @@ class ElementwisePowKernel : public framework::OpKernel<T> { ...@@ -36,7 +36,8 @@ class ElementwisePowKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z); ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
PowFunctor<T>(), z);
} }
}; };
......
...@@ -34,7 +34,8 @@ class ElementwiseSubKernel : public framework::OpKernel<T> { ...@@ -34,7 +34,8 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z); ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
} }
}; };
......
@@ -21,13 +21,6 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename T>
using EigenMatrixMapRowMajor = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
class LayerNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -108,7 +101,6 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Layer Normalization.
Layer Norm has been implemented as discussed in the paper:
https://arxiv.org/abs/1607.06450
...
@@ -116,75 +108,6 @@ https://arxiv.org/abs/1607.06450
}
};
template <typename T>
class LayerNormKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto *output = ctx.Output<Tensor>("Y");
auto *mean = ctx.Output<Tensor>("Mean");
auto *var = ctx.Output<Tensor>("Variance");
output->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
auto squre = [](T ele) { return ele * ele; };
auto add_epslion = [epsilon](T ele) { return ele + epsilon; };
mean_map = input_map.rowwise().mean();
var_map = (input_map - mean_map.replicate(1, right))
.unaryExpr(squre)
.rowwise()
.mean()
.unaryExpr(add_epslion);
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// TODO(zcd): Some thinking about output_map, is it appropriate that
// `output_map` and `input_map` point to the same memory.
auto inv_std = var_map.unaryExpr(inv_std_func);
if (scale && bias) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1)) +
bias_map.replicate(left, 1);
} else if (scale) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1));
} else if (bias) {
auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right)) +
bias_map.replicate(left, 1);
} else {
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right));
}
}
};
class LayerNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -237,125 +160,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
}
};
template <typename T>
class LayerNormGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
const auto *mean = ctx.Input<Tensor>("Mean");
const auto *var = ctx.Input<Tensor>("Variance");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
auto var_map = ConstEigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
d_bias_map = d_y_map.colwise().sum();
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
auto d_scale_map =
EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// There are two equation to compute d_scale. One uses "Y" and the other
// does not use "Y"
d_scale_map =
((x_map - mean_map.replicate(1, right))
.cwiseProduct(
var_map.unaryExpr(inv_std_func).replicate(1, right))
.cwiseProduct(d_y_map))
.colwise()
.sum();
}
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
auto triple_product_func = [](T ele) { return ele * ele * ele; };
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// TODO(zcd): these code can be refined
if (d_scale) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
// dy_dx
auto dx_end = var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.cwiseProduct(scale_map.replicate(left, 1));
// dy_dmean_dx
auto dx_mean = (T(-1.0) / right) *
var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.cwiseProduct(scale_map.replicate(left, 1))
.rowwise()
.sum()
.replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = var_map.unaryExpr(inv_std_func)
.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
(T(-1.0) / right) *
(x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
d_x_map = dx_end + dx_mean + dx_var;
} else {
// dy_dx
auto dx_end = var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map);
// dy_dmean_dx
auto dx_mean = (T(-1.0) / right) *
var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.rowwise()
.sum()
.replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = var_map.unaryExpr(inv_std_func)
.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
(T(-1.0) / right) *
(x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
d_x_map = dx_end + dx_mean + dx_var;
}
}
}
};
} // namespace operators
} // namespace paddle
@@ -363,8 +167,9 @@ namespace ops = paddle::operators;
REGISTER_OP(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
layer_norm_grad, ops::LayerNormGradOp);
REGISTER_OP_CPU_KERNEL(
-layer_norm,
-ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>);
+layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
+ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
layer_norm_grad,
-ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>);
+ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, double>);
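For reference, the forward computation that both the removed Eigen-based kernel and the new header implementation perform (this restates the standard formulation from the paper cited in the DOC string; it is not text from the commit) is, with the input flattened to a [left, right] matrix and statistics taken per row i:

\mu_i = \frac{1}{right} \sum_j x_{ij}, \qquad
\sigma_i^2 = \frac{1}{right} \sum_j (x_{ij} - \mu_i)^2
\hat{x}_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}, \qquad
y_{ij} = \gamma_j \, \hat{x}_{ij} + \beta_j

where \gamma is Scale and \beta is Bias (each applied only when the corresponding input is provided), and Mean/Variance hold the per-row statistics.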
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/layer_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
layer_norm,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>);
@@ -16,19 +16,222 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
+#include "paddle/operators/elementwise_op_function.h"
+#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename T>
struct SubAndSquareFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
};
template <typename T>
struct DivAndSqrtFunctor {
explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
inline HOSTDEVICE T operator()(T a, T b) const {
return a / (sqrt(b + epsilon_));
}
private:
T epsilon_;
};
template <typename T>
struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename T>
struct MulInvVarFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
return a * std::sqrt(1.0 / b);
}
};
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename DeviceContext, typename T>
class LayerNormKernel : public framework::OpKernel<T> {
public:
-void Compute(const framework::ExecutionContext& ctx) const override;
+void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
auto x = *ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
auto *mean = ctx.Output<Tensor>("Mean");
auto *var = ctx.Output<Tensor>("Variance");
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto x_dims = x.dims();
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
x.Resize(matrix_shape);
Tensor out;
out.ShareDataWith(*y);
out.Resize(matrix_shape);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::RowwiseMean<DeviceContext, T> row_mean;
// get mean
row_mean(dev_ctx, x, mean);
// get variance
ElementwiseComputeEx<SubAndSquareFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), &out);
row_mean(dev_ctx, out, var);
// get x_norm
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), &out);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, &out, var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &out);
if (scale) {
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &out, scale, /*axis*/ 1, MulFunctor<T>(), &out);
}
if (bias) {
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out);
}
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class LayerNormGradKernel : public framework::OpKernel<T> { class LayerNormGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override; void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto x = *ctx.Input<Tensor>("X");
auto *y = ctx.Input<Tensor>("Y");
auto *mean = ctx.Input<Tensor>("Mean");
auto *var = ctx.Input<Tensor>("Variance");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
auto d_y = *ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto &x_dims = x.dims();
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
d_y.Resize(matrix_shape);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::ColwiseSum<DeviceContext, T> colwise_sum;
Tensor temp;
Tensor temp_norm;
if (d_scale || d_x) {
x.Resize(matrix_shape);
temp.mutable_data<T>(matrix_shape, ctx.GetPlace());
if (!(bias && scale)) {
temp_norm.ShareDataWith(*y);
temp_norm.Resize(matrix_shape);
} else {
temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
// get x_norm
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
}
}
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
colwise_sum(dev_ctx, d_y, d_bias);
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor<T>(), &temp);
colwise_sum(dev_ctx, temp, d_scale);
}
if (d_x) {
framework::DDim vec_shape({left});
d_x->mutable_data<T>(ctx.GetPlace());
auto dx_dim = d_x->dims();
Tensor temp_vec;
temp_vec.mutable_data<T>(vec_shape, ctx.GetPlace());
math::RowwiseMean<DeviceContext, T> row_mean;
if (d_scale) {
// dy_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, scale, /*axis*/ 1, MulFunctor<T>(), &temp);
framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
} else {
// dy_dx
framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, d_y, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
}
// dy_var_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor<T>(), &temp);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp, /*axis*/ 0, SubFunctor<T>(), d_x);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, d_x, var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), d_x);
d_x->Resize(dx_dim);
}
}
};
} // namespace operators
......
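The three terms assembled in the gradient kernel above (the dy_dx, dy_dmean_dx, and dy_var_dx steps) correspond, in the notation of the forward pass and assuming the standard layer-norm derivative (a restatement for orientation, not text from the commit), to:

\frac{\partial L}{\partial \beta_j} = \sum_i dy_{ij}, \qquad
\frac{\partial L}{\partial \gamma_j} = \sum_i dy_{ij} \, \hat{x}_{ij}
\frac{\partial L}{\partial x_{ij}} = \frac{1}{\sqrt{\sigma_i^2 + \epsilon}}
  \Big( g_{ij} - \frac{1}{right} \sum_k g_{ik}
        - \hat{x}_{ij} \cdot \frac{1}{right} \sum_k g_{ik} \, \hat{x}_{ik} \Big),
  \qquad g_{ij} = dy_{ij} \, \gamma_j

with g_{ij} = dy_{ij} when Scale is absent; the final DivAndSqrtFunctor application supplies the leading 1/\sqrt{\sigma_i^2 + \epsilon} factor.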
@@ -331,6 +331,12 @@ template struct RowwiseAdd<platform::CPUDeviceContext, double>;
template struct ColwiseSum<platform::CPUDeviceContext, float>;
template struct ColwiseSum<platform::CPUDeviceContext, double>;
template struct RowwiseSum<platform::CPUDeviceContext, float>;
template struct RowwiseSum<platform::CPUDeviceContext, double>;
template struct RowwiseMean<platform::CPUDeviceContext, float>;
template struct RowwiseMean<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
@@ -325,6 +325,31 @@ void ColwiseSum<platform::CUDADeviceContext, double>::operator()(
vector->data<double>());
}
template struct RowwiseSum<platform::CUDADeviceContext, float>;
// template struct RowwiseSum<platform::CUDADeviceContext, double>;
// TODO(zcd): Following ColwiseSum format, need to confirm.
// The RowwiseSum<platform::CUDADeviceContext, double> failed in debug mode,
// and only failed for this case. So reimplemented it.
template <>
void RowwiseSum<platform::CUDADeviceContext, double>::operator()(
const platform::CUDADeviceContext& context, const framework::Tensor& input,
framework::Tensor* vector) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0]);
framework::Tensor one;
one.mutable_data<double>({size}, context.GetPlace());
SetConstant<platform::CUDADeviceContext, double> set;
set(context, &one, static_cast<double>(1.0));
gemv<platform::CUDADeviceContext, double>(
context, true, static_cast<int>(in_dims[1]), static_cast<int>(in_dims[0]),
1.0, one.data<double>(), input.data<double>(), 0.0,
vector->data<double>());
}
template struct RowwiseMean<platform::CUDADeviceContext, float>;
template struct RowwiseMean<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
@@ -128,6 +128,18 @@ struct ColwiseSum {
framework::Tensor* vec);
};
template <typename DeviceContext, typename T>
struct RowwiseSum {
void operator()(const DeviceContext& context, const framework::Tensor& input,
framework::Tensor* vec);
};
template <typename DeviceContext, typename T>
struct RowwiseMean {
void operator()(const DeviceContext& context, const framework::Tensor& input,
framework::Tensor* vec);
};
} // namespace math
} // namespace operators
} // namespace paddle
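RowwiseSum and RowwiseMean follow the existing ColwiseSum interface: each reduces a 2-D [N, D] tensor into a length-N vector. A minimal usage sketch (variable names and shapes are illustrative, not from the commit), mirroring how the layer_norm kernel produces its Mean output:

// Illustrative sketch only.
framework::Tensor input;            // shape [N, D], already filled
framework::Tensor mean;
framework::DDim vec_shape({N});
mean.mutable_data<float>(vec_shape, place);  // must hold exactly N elements

math::RowwiseMean<platform::CPUDeviceContext, float> row_mean;
row_mean(dev_ctx, input, &mean);    // mean[i] = (1/D) * sum_j input[i, j]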
@@ -87,6 +87,88 @@ class ColwiseSum<platform::CPUDeviceContext, T> {
}
};
template <typename DeviceContext, typename T>
void RowwiseMean<DeviceContext, T>::operator()(const DeviceContext& context,
const framework::Tensor& input,
framework::Tensor* out) {
auto in_dims = input.dims();
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
PADDLE_ENFORCE_EQ(out->numel(), in_dims[0]);
auto in = framework::EigenMatrix<T>::From(input);
auto vec = framework::EigenVector<T>::Flatten(*out);
vec.device(*context.eigen_device()) = in.mean(Eigen::array<int, 1>({{1}}));
}
// TODO(zcd): Following ColwiseSum format, need to confirm.
// Specialize for CPU, since Eigen implement a general reduce. However,
// rowwise-sum can be easily implemented. General reduce has a huge overhead in
// CPU
template <typename T>
class RowwiseMean<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* out) {
auto& in_dims = input.dims();
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), height);
auto inv_size = 1.0 / size;
T* out_buf = out->mutable_data<T>(out->place());
const T* in_buf = input.data<T>();
for (size_t i = 0; i < static_cast<size_t>(height); ++i) {
T sum = 0;
for (size_t j = 0; j < static_cast<size_t>(size); ++j) {
sum += in_buf[i * size + j];
}
out_buf[i] = sum * inv_size;
}
}
};
template <typename DeviceContext, typename T>
void RowwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
const framework::Tensor& input,
framework::Tensor* out) {
auto in_dims = input.dims();
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
PADDLE_ENFORCE_EQ(out->numel(), in_dims[0]);
auto in = framework::EigenMatrix<T>::From(input);
auto vec = framework::EigenVector<T>::Flatten(*out);
vec.device(*context.eigen_device()) = in.sum(Eigen::array<int, 1>({{1}}));
}
// TODO(zcd): Following ColwiseSum format, need to confirm.
// Specialize for CPU, since Eigen implement a general reduce. However,
// rowwise-sum can be easily implemented. General reduce has a huge overhead in
// CPU
template <typename T>
class RowwiseSum<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* out) {
auto& in_dims = input.dims();
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), size);
T* out_buf = out->mutable_data<T>(out->place());
const T* in_buf = input.data<T>();
for (size_t i = 0; i < static_cast<size_t>(height); ++i) {
T sum = 0;
for (size_t j = 0; j < static_cast<size_t>(size); ++j) {
sum += in_buf[i * size + j];
}
out_buf[i] = sum;
}
}
};
} // namespace math
} // namespace operators
} // namespace paddle
@@ -300,6 +300,9 @@ class DistributeTranspiler:
pass
return orig_shape
def _op_input_var(self, op, varname):
pass
def _is_op_on_pserver(self, endpoint, all_ops, idx):
"""
Recursively check if the op need to run on current server.
@@ -309,29 +312,35 @@ class DistributeTranspiler:
p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
]
op = all_ops[idx]
-if op.inputs.has_key("Param"):
-if op.inputs["Param"].name in param_names:
+input_names = set(op.input_names)
+# TODO(typhoonzero): using Param and Grad input name to identify
+# that the operator is an optimization operator, need a better way.
+if "Param" in input_names:
+if op.input("Param")[0] in param_names:
return True
else:
for n in param_names:
-if same_or_split_var(n, op.inputs[
-"Param"].name) and n != op.inputs["Param"].name:
+if same_or_split_var(n, op.input("Param")[0]) \
+and n != op.input("Param")[0]:
return True
return False
else:
j = idx - 1
while j >= 0:
prev_op = all_ops[j]
-prev_output_names = [o.name for o in prev_op.outputs.values()]
-prev_input_names = [o.name for o in prev_op.inputs.values()]
+# prev_output_names = [o.name for o in prev_op.outputs.values()]
+# prev_input_names = [o.name for o in prev_op.inputs.values()]
+# NOTE(typhoonzero): consider list input/output
+prev_output_names = prev_op.desc.output_arg_names()
+prev_input_names = prev_op.desc.input_arg_names()
found1 = False
found2 = False
-for _, v in op.inputs.iteritems():
-if v.name in prev_output_names:
+for varname in op.desc.input_arg_names():
+if varname in prev_output_names:
found1 = self._is_op_on_pserver(endpoint, all_ops, j)
# later ops may produce output for prev op's next batch use.
-for _, v in op.outputs.iteritems():
-if v.name in prev_input_names:
+for varname in op.desc.output_arg_names():
+if varname in prev_input_names:
found2 = self._is_op_on_pserver(endpoint, all_ops, j)
if found1 or found2:
return True
@@ -342,11 +351,11 @@ class DistributeTranspiler:
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
-for key, var in opt_op.inputs.iteritems():
+for key in opt_op.input_names:
if key == "Grad":
grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
-if same_or_split_var(g.name, var.name):
+if same_or_split_var(g.name, opt_op.input(key)[0]):
grad_block = g
break
if not grad_block:
@@ -376,7 +385,7 @@ class DistributeTranspiler:
# param is already created on global program
param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]:
-if same_or_split_var(p.name, var.name):
+if same_or_split_var(p.name, opt_op.input(key)[0]):
param_block = p
break
if not param_block:
@@ -389,11 +398,12 @@ class DistributeTranspiler:
new_inputs[key] = tmpvar
-for key, var in opt_op.inputs.iteritems():
+for key in opt_op.input_names:
if key in ["Param", "Grad"]:
continue
# update accumulator variable shape
param_shape = new_inputs["Param"].shape
+var = program.global_block().vars[opt_op.input(key)[0]]
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape)
tmpvar = program.global_block().create_var(
@@ -412,30 +422,44 @@ class DistributeTranspiler:
shape=new_shape)
# change output's ParamOut variable
-opt_op.outputs["ParamOut"] = new_inputs["Param"]
+outputs = self._get_output_map_from_op(program.global_block(), opt_op)
+outputs["ParamOut"] = new_inputs["Param"]
program.global_block().append_op(
type=opt_op.type,
inputs=new_inputs,
-outputs=opt_op.outputs,
+outputs=outputs,
attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
# Append the ops for parameters that do not need to be optimized/updated
-for _, var in opt_op.inputs.iteritems():
-program.global_block().create_var(
-name=var.name,
-persistable=var.persistable,
-dtype=var.dtype,
-shape=var.shape)
-pserver_program.global_block().create_var(
-name=var.name,
-persistable=var.persistable,
-dtype=var.dtype,
-shape=var.shape)
+inputs = self._get_input_map_from_op(self.program.global_block().vars,
+opt_op)
+for var in inputs.itervalues():
+if type(var) == list:
+varlist = var
+else:
+varlist = [var]
+for var in varlist:
+# TODO(typhoonzero): will remove below line later.
+program.global_block().create_var(
+name=var.name,
+persistable=var.persistable,
+dtype=var.dtype,
+shape=var.shape)
+if not pserver_program.global_block().vars.has_key(var.name):
+pserver_program.global_block().create_var(
+name=var.name,
+persistable=var.persistable,
+dtype=var.dtype,
+shape=var.shape)
+outputs = self._get_output_map_from_op(self.program.global_block().vars,
+opt_op)
program.global_block().append_op(
type=opt_op.type,
-inputs=opt_op.inputs,
-outputs=opt_op.outputs,
+inputs=inputs,
+outputs=outputs,
attrs=opt_op.attrs)
def get_pserver_program(self, endpoint):
@@ -472,7 +496,7 @@ class DistributeTranspiler:
self.optimize_ops, idx)
if not is_op_on_pserver:
continue
-if opt_op.inputs.has_key("Grad"):
+if "Grad" in opt_op.desc.input_arg_names():
self._append_pserver_ops(optimize_sub_program, pserver_program,
opt_op, endpoint)
else:
@@ -499,6 +523,30 @@ class DistributeTranspiler:
pserver_program.sync_with_cpp()
return pserver_program
def _get_input_map_from_op(self, varmap, op):
iomap = dict()
for key in op.input_names:
vars = []
for varname in op.input(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def _get_output_map_from_op(self, varmap, op):
iomap = dict()
for key in op.output_names:
vars = []
for varname in op.output(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap
def get_startup_program(self, endpoint, pserver_program):
"""
Get startup program for current parameter server.
@@ -529,17 +577,21 @@ class DistributeTranspiler:
# 2. rename op outputs
for op in orig_s_prog.global_block().ops:
+new_inputs = dict()
new_outputs = dict()
# do not append startup op if var is not on this pserver
op_on_pserver = False
-for key, var in op.outputs.iteritems():
-newname, _ = _get_splited_name_and_shape(var.name)
+for key in op.output_names:
+newname, _ = _get_splited_name_and_shape(op.output(key)[0])
if newname:
op_on_pserver = True
new_outputs[key] = created_var_map[newname]
-elif var.name in pserver_vars:
+elif op.output(key)[0] in pserver_vars:
op_on_pserver = True
-new_outputs[key] = pserver_vars[var.name]
+new_outputs[key] = pserver_vars[op.output(key)[0]]
+# most startup program ops have no inputs
+new_inputs = self._get_input_map_from_op(pserver_vars, op)
if op_on_pserver:
if op.type in [
@@ -548,7 +600,7 @@ class DistributeTranspiler:
op.attrs["shape"] = new_outputs["Out"].shape
s_prog.global_block().append_op(
type=op.type,
-inputs=op.inputs,
+inputs=new_inputs,
outputs=new_outputs,
attrs=op.attrs)
return s_prog
@@ -20,6 +20,8 @@ import paddle.v2.fluid.core as core
from paddle.v2.fluid.op import Operator
from paddle.v2.fluid.framework import grad_var_name
+np.random.random(123)
def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
x_shape = x.shape
@@ -62,9 +64,9 @@ def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
grad_x = dx_end + d_mean + d_std
-grad_y.shape = x_shape
-x.shape = x_shape
+grad_x.shape, x.shape, grad_y.shape = x_shape, x_shape, x_shape
scale.shape = scale_shape
+var.shape, mean.shape = [N, ], [N, ]
return grad_x, d_scale, d_bias
@@ -112,10 +114,7 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
class TestLayerNormdOp(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
-self.assertTrue(
-np.allclose(
-np.array(tensor).reshape(np_array.shape), np_array, atol=atol),
-msg)
+self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def __assert_grad_close(self,
tensor,
@@ -123,7 +122,7 @@ class TestLayerNormdOp(OpTest):
name,
place,
max_relative_error=0.02):
-a = np.array(tensor).reshape(np_array.shape)
+a = np.array(tensor)
b = np_array
abs_a = np.abs(a)
abs_a[abs_a < 1e-5] = 1
@@ -151,7 +150,7 @@ class TestLayerNormdOp(OpTest):
x_shape = shape
D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
scale_shape = [D]
-np.random.random(123)
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
......