Commit a9374181 authored by Megvii Engine Team

feat(mgb/opr): add layernorm forward and backward kernel

GitOrigin-RevId: 0cd484e753a4fbfb88cf81ddbde6ad80b844e69d
Parent bdca240f
@@ -1936,6 +1936,75 @@ protected:
const TensorLayout& grad_s, size_t workspace_in_bytes);
};
class LayerNormBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase);
DEF_OPR_PARAM(LayerNorm);
protected:
void deduce_layout_fwd(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
TensorLayout& rstd);
void check_layout_fwd(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
const TensorLayout& rstd);
};
class LayerNormForward : public LayerNormBase {
DEF_OPR_IMPL(LayerNormForward, LayerNormBase, 3, 3);
public:
virtual void exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
TensorLayout& rstd);
virtual size_t get_workspace_in_bytes(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
const TensorLayout& rstd) = 0;
protected:
void check_exec(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
const TensorLayout& rstd, size_t workspace_in_bytes);
};
using LayerNorm = LayerNormForward;
class LayerNormBackward : public LayerNormBase {
DEF_OPR_IMPL(LayerNormBackward, LayerNormBase, 5, 3);
public:
virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& diff, const TensorLayout& data,
const TensorLayout& weight, const TensorLayout& mean,
const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight,
TensorLayout& dbias);
virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& data,
const TensorLayout& weight, const TensorLayout& mean,
const TensorLayout& rstd, const TensorLayout& ddata,
const TensorLayout& dweight, const TensorLayout& dbias) = 0;
protected:
void check_exec(
const TensorLayout& diff, const TensorLayout& data,
const TensorLayout& weight, const TensorLayout& mean,
const TensorLayout& rstd, const TensorLayout& ddata,
const TensorLayout& dweight, const TensorLayout& dbias,
size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
@@ -1212,3 +1212,10 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES]
)
)
(pdef('LayerNorm')
.add_fields('bool', 'affine', 'true')
.add_fields('float32', 'eps', '1e-5f')
.add_fields('uint64', 'normalized_dim', '1')
.add_fields('uint64', 'normalized_size', '1')
)
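For reference, these fields map onto the standard LayerNorm definition used by the kernels below: the last `normalized_dim` axes of the input are flattened into one slice of `normalized_size` elements, and each slice is normalized as

$$ y = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}\cdot\gamma + \beta, $$

where $\varepsilon$ is `eps`, and the affine transform $\gamma,\beta$ (weight/bias) is applied only when `affine` is true.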
@@ -209,7 +209,10 @@ private:
cb(LSQBackward) \
cb(Fill) \
cb(PaddingForward) \
cb(PaddingBackward)
cb(PaddingBackward) \
cb(LayerNormForward) \
cb(LayerNormBackward)
// clang-format on
/*!
......
/**
* \file dnn/src/common/layer_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void LayerNormBase::deduce_layout_fwd(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
MEGDNN_MARK_USED_VAR(weight);
MEGDNN_MARK_USED_VAR(bias);
auto p = param();
TensorShape unnormalized_shape;
unnormalized_shape.ndim = data.ndim - p.normalized_dim;
for (size_t i = 0; i < unnormalized_shape.ndim; ++i) {
unnormalized_shape.shape[i] = data.shape[i];
}
TensorLayout unnormalized_layout =
TensorLayout(unnormalized_shape, dtype::Float32());
dst = data;
mean = unnormalized_layout;
rstd = unnormalized_layout;
}
void LayerNormBase::check_layout_fwd(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) {
megdnn_assert_contiguous(data);
megdnn_assert_contiguous(weight);
megdnn_assert_contiguous(bias);
megdnn_assert_contiguous(dst);
megdnn_assert_contiguous(mean);
megdnn_assert_contiguous(rstd);
auto errmsg = [&]() {
return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " +
megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " +
megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd);
};
MEGDNN_MARK_USED_VAR(errmsg);
auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool {
if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype &&
lhs.format == rhs.format))
return false;
for (size_t i = 0; i < lhs.ndim; ++i) {
if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) {
return false;
}
}
return true;
};
megdnn_assert(equal_layout(data, dst), "%s", errmsg().c_str());
megdnn_assert(equal_layout(weight, bias), "%s", errmsg().c_str());
megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str());
auto p = param();
uint64_t normalized_dim = p.normalized_dim;
size_t unnormalized_dim = data.ndim - normalized_dim;
megdnn_assert(
normalized_dim < data.ndim,
"the dims of normalized shape should smaller than input dims");
for (size_t i = 0; i < unnormalized_dim; ++i) {
megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str());
}
if (p.affine) {
for (size_t i = 0; i < normalized_dim; ++i) {
megdnn_assert(
data.shape[unnormalized_dim + i] == weight.shape[i], "%s",
errmsg().c_str());
}
}
}
void LayerNormForward::deduce_layout(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
deduce_layout_fwd(data, weight, bias, dst, mean, rstd);
}
void LayerNormForward::check_exec(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd,
size_t workspace_in_bytes) {
check_layout_fwd(data, weight, bias, dst, mean, rstd);
auto required_workspace_in_bytes =
get_workspace_in_bytes(data, weight, bias, dst, mean, rstd);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void LayerNormBackward::deduce_layout(
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata,
TensorLayout& dweight, TensorLayout& dbias) {
MEGDNN_MARK_USED_VAR(diff);
MEGDNN_MARK_USED_VAR(mean);
MEGDNN_MARK_USED_VAR(rstd);
ddata = data;
dweight = weight;
dbias = weight;
}
void LayerNormBackward::check_exec(
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata,
const TensorLayout& dweight, const TensorLayout& dbias,
size_t workspace_in_bytes) {
auto p = param();
auto required_workspace_in_bytes = get_workspace_in_bytes(
diff, data, weight, mean, rstd, ddata, dweight, dbias);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert_contiguous(diff);
megdnn_assert_contiguous(data);
megdnn_assert_contiguous(mean);
megdnn_assert_contiguous(rstd);
megdnn_assert_contiguous(ddata);
if (p.affine) {
megdnn_assert_contiguous(weight);
megdnn_assert_contiguous(dweight);
megdnn_assert_contiguous(dbias);
}
auto errmsg = [&]() {
return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " +
megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " +
megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " +
megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias);
};
MEGDNN_MARK_USED_VAR(errmsg);
auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool {
if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype &&
lhs.format == rhs.format))
return false;
for (size_t i = 0; i < lhs.ndim; ++i) {
if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) {
return false;
}
}
return true;
};
megdnn_assert(equal_layout(data, ddata), "%s", errmsg().c_str());
megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str());
if (p.affine) {
megdnn_assert(equal_layout(weight, dweight), "%s", errmsg().c_str());
megdnn_assert(equal_layout(weight, dbias), "%s", errmsg().c_str());
}
size_t normalized_dim = p.normalized_dim;
size_t unnormalized_dim = data.ndim - normalized_dim;
for (size_t i = 0; i < unnormalized_dim; ++i) {
megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str());
}
if (p.affine) {
for (size_t i = 0; i < normalized_dim; ++i) {
megdnn_assert(
data.shape[unnormalized_dim + i] == weight.shape[i], "%s",
errmsg().c_str());
}
}
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -135,6 +135,8 @@ DEF(CheckNonFinite, 2, true, true);
DEF(LSQForward, 5, true, true);
DEF(LSQBackward, 7, true, false);
DEF(Fill, 1, true, false);
DEF(LayerNormForward, 6, true, true);
DEF(LayerNormBackward, 8, true, true);
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -45,6 +45,7 @@
#include "src/cuda/images2neibs/opr_impl.h"
#include "src/cuda/indexing_multi_axis_vec/opr_impl.h"
#include "src/cuda/indexing_one_hot/opr_impl.h"
#include "src/cuda/layer_norm/opr_impl.h"
#include "src/cuda/linspace/opr_impl.h"
#include "src/cuda/local/opr_impl.h"
#include "src/cuda/local_share/opr_impl.h"
......
This diff is collapsed.
/**
* \file dnn/src/cuda/layer_norm/layer_norm.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <cuda_runtime_api.h>
namespace megdnn {
namespace cuda {
namespace layer_norm {
template <typename T, typename T_ACC>
void forward(
T* X, T* gamma, T* beta, int64_t M, int64_t N, T_ACC eps, T* Y, T_ACC* mean,
T_ACC* rstd, cudaStream_t stream);
template <typename T, typename T_ACC>
void backward(
const T* dY_data, const T* X_data, const T_ACC* mean_data,
const T_ACC* rstd_data, const T* gamma_data, int64_t M, int64_t N, T* dX_data,
T* dgamma_data, T* dbeta_data, cudaStream_t stream);
} // namespace layer_norm
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
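The `M`/`N` arguments in these signatures follow the usual flattening convention (see the `n_slices` computation in opr_impl.cpp below): for an input of shape $(S_1,\dots,S_k,\,N_1,\dots,N_d)$ with $d$ = `normalized_dim`,

$$ M = \prod_{i=1}^{k} S_i, \qquad N = \prod_{j=1}^{d} N_j = \texttt{normalized\_size}, $$

and `gamma`/`beta` are passed as null pointers when `affine` is false.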
/**
* \file dnn/src/cuda/layer_norm/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/layer_norm/opr_impl.h"
#include "src/cuda/layer_norm/layer_norm_cuda.cuh"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
void LayerNormForwardImpl::exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) {
check_exec(
data.layout, weight.layout, bias.layout, dst.layout, mean.layout,
rstd.layout, workspace.size);
auto p = param();
float eps = p.eps;
bool affine = p.affine;
uint64_t slice_length = p.normalized_size;
uint64_t slice_dim = p.normalized_dim;
uint64_t n_slices = 1;
for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) {
n_slices = n_slices * data.layout.shape[i];
}
auto stream = cuda_stream(handle());
using namespace ::megdnn::cuda::layer_norm;
#define cb(DType) \
if (data.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
using T_ACC = float; \
forward<T, T_ACC>( \
data.ptr<T>(), affine ? weight.ptr<T>() : nullptr, \
affine ? bias.ptr<T>() : nullptr, static_cast<int64_t>(n_slices), \
static_cast<int64_t>(slice_length), static_cast<T_ACC>(eps), \
dst.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), stream); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
void LayerNormBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) {
check_exec(
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout,
ddata.layout, dweight.layout, dbias.layout, workspace.size);
auto p = param();
bool affine = p.affine;
uint64_t slice_length = p.normalized_size;
uint64_t slice_dim = p.normalized_dim;
uint64_t n_slices = 1;
for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) {
n_slices = n_slices * data.layout.shape[i];
}
auto stream = cuda_stream(handle());
using namespace ::megdnn::cuda::layer_norm;
#define cb(DType) \
if (data.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
using T_ACC = float; \
backward<T, T_ACC>( \
diff.ptr<T>(), data.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), \
affine ? weight.ptr<T>() : nullptr, n_slices, slice_length, \
ddata.ptr<T>(), affine ? dweight.ptr<T>() : nullptr, \
affine ? dbias.ptr<T>() : nullptr, stream); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/layer_norm/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/cudnn_wrapper.h"
namespace megdnn {
namespace cuda {
class LayerNormForwardImpl final : public LayerNormForward {
public:
using LayerNormForward::LayerNormForward;
void exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
class LayerNormBackwardImpl final : public LayerNormBackward {
public:
using LayerNormBackward::LayerNormBackward;
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -47,6 +47,7 @@
#include "src/naive/images2neibs/opr_impl.h"
#include "src/naive/indexing_multi_axis_vec/opr_impl.h"
#include "src/naive/indexing_one_hot/opr_impl.h"
#include "src/naive/layer_norm/opr_impl.h"
#include "src/naive/linspace/opr_impl.h"
#include "src/naive/local/opr_impl.h"
#include "src/naive/local_share/opr_impl.h"
......
/**
* \file dnn/src/naive/layer_norm/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/layer_norm/opr_impl.h"
#include <algorithm>
#include "src/common/utils.h"
#include "src/naive/handle.h"
using namespace megdnn;
using namespace naive;
namespace {
using Param = megdnn::LayerNorm::Param;
template <typename T, typename T_ACC = float>
void forward(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
const Param& param) {
float eps = param.eps;
bool affine = param.affine;
uint64_t slice_length = param.normalized_size;
uint64_t slice_dim = param.normalized_dim;
uint64_t n_slices = 1;
for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) {
n_slices = n_slices * data.layout.shape[i];
}
for (size_t i = 0; i < n_slices; i++) {
T_ACC slice_sum = static_cast<T_ACC>(0.0f);
for (size_t j = 0; j < slice_length; j++) {
auto value = data.ptr<T>()[i * slice_length + j];
slice_sum += value;
}
T_ACC slice_mean = static_cast<T_ACC>(slice_sum / slice_length);
T_ACC slice_var = static_cast<T_ACC>(0.0f);
for (size_t j = 0; j < slice_length; j++) {
slice_var += (data.ptr<T>()[i * slice_length + j] - slice_mean) *
(data.ptr<T>()[i * slice_length + j] - slice_mean);
}
slice_var = slice_var / slice_length;
T_ACC slice_std = static_cast<T_ACC>(sqrt(slice_var + eps));
for (size_t j = 0; j < slice_length; j++) {
dst.ptr<T>()[i * slice_length + j] =
(data.ptr<T>()[i * slice_length + j] - slice_mean) / slice_std;
if (affine) {
dst.ptr<T>()[i * slice_length + j] =
dst.ptr<T>()[i * slice_length + j] * weight.ptr<T>()[j] +
bias.ptr<T>()[j];
}
}
mean.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_mean);
rstd.ptr<T_ACC>()[i] = static_cast<T_ACC>(1.0 / slice_std);
}
}
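Spelled out, the loop above computes, for each slice $i$ of length $N$ (= `slice_length`):

$$ \mu_i = \frac{1}{N}\sum_{j} x_{ij}, \qquad \sigma_i^2 = \frac{1}{N}\sum_{j} (x_{ij}-\mu_i)^2, \qquad y_{ij} = \frac{x_{ij}-\mu_i}{\sqrt{\sigma_i^2+\varepsilon}}\,\gamma_j + \beta_j, $$

with $\mu_i$ stored in `mean` and the reciprocal standard deviation $r_i = 1/\sqrt{\sigma_i^2+\varepsilon}$ stored in `rstd` for reuse by the backward pass.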
template <typename T, typename T_ACC = float>
void backward(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, const Param& param) {
bool affine = param.affine;
uint64_t slice_length = param.normalized_size;
uint64_t slice_dim = param.normalized_dim;
uint64_t n_slices = 1;
for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) {
n_slices = n_slices * data.layout.shape[i];
}
if (affine) {
for (size_t i = 0; i < slice_length; ++i) {
dweight.ptr<T>()[i] = 0;
dbias.ptr<T>()[i] = 0;
}
for (size_t i = 0; i < n_slices; ++i) {
for (size_t j = 0; j < slice_length; ++j) {
dweight.ptr<T>()[j] +=
(data.ptr<T>()[i * slice_length + j] - mean.ptr<T_ACC>()[i]) *
rstd.ptr<T_ACC>()[i] * diff.ptr<T>()[i * slice_length + j];
dbias.ptr<T>()[j] += diff.ptr<T>()[i * slice_length + j];
}
}
}
for (size_t i = 0; i < n_slices; ++i) {
T_ACC ds = static_cast<T_ACC>(0.0f);
T_ACC db = static_cast<T_ACC>(0.0f);
T_ACC a = static_cast<T_ACC>(0.0f);
T_ACC b = static_cast<T_ACC>(0.0f);
T_ACC c = static_cast<T_ACC>(0.0f);
for (size_t j = 0; j < slice_length; ++j) {
auto value = data.ptr<T>()[i * slice_length + j];
auto diff_v = diff.ptr<T>()[i * slice_length + j];
auto weight_v = affine ? weight.ptr<T>()[j] : static_cast<T>(1.0f);
db += diff_v * weight_v;
ds += diff_v * value * weight_v;
}
a = rstd.ptr<T_ACC>()[i];
b = (db * mean.ptr<T_ACC>()[i] - ds) * a * a * a / slice_length;
c = -b * mean.ptr<T_ACC>()[i] - db * a / slice_length;
for (uint64_t j = 0; j < slice_length; j++) {
auto weight_v = affine ? weight.ptr<T>()[j] : static_cast<T>(1.0f);
ddata.ptr<T>()[i * slice_length + j] =
diff.ptr<T>()[i * slice_length + j] * a * weight_v +
data.ptr<T>()[i * slice_length + j] * b + c;
}
}
}
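The `a`/`b`/`c` coefficients are a factored form of the standard LayerNorm input gradient. Writing $r_i$ for the saved rstd and $\widehat{dy}_{ij} = dy_{ij}\,\gamma_j$ (with $\gamma_j = 1$ when `affine` is false), the inner loop accumulates $db_i = \sum_j \widehat{dy}_{ij}$ and $ds_i = \sum_j \widehat{dy}_{ij}\,x_{ij}$, and then

$$ a_i = r_i, \qquad b_i = \frac{(db_i\,\mu_i - ds_i)\,r_i^3}{N}, \qquad c_i = -b_i\,\mu_i - \frac{db_i\,r_i}{N}, \qquad dx_{ij} = a_i\,\widehat{dy}_{ij} + b_i\,x_{ij} + c_i, $$

which expands to the familiar

$$ dx_{ij} = r_i\Big(\widehat{dy}_{ij} - \frac{1}{N}\sum_{j'} \widehat{dy}_{ij'} - \hat{x}_{ij}\,\frac{1}{N}\sum_{j'} \widehat{dy}_{ij'}\,\hat{x}_{ij'}\Big), \qquad \hat{x}_{ij} = (x_{ij}-\mu_i)\,r_i. $$

The affine branch likewise matches $d\gamma_j = \sum_i (x_{ij}-\mu_i)\,r_i\,dy_{ij}$ and $d\beta_j = \sum_i dy_{ij}$.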
} // namespace
namespace megdnn {
namespace naive {
void LayerNormForwardImpl::exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) {
check_exec(
data.layout, weight.layout, bias.layout, dst.layout, mean.layout,
rstd.layout, workspace.size);
#define cb(DType) \
if (data.layout.dtype == DType()) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(forward<typename DTypeTrait<DType>::ctype>( \
data, weight, bias, dst, mean, rstd, param())); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
void LayerNormBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) {
check_exec(
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout,
ddata.layout, dweight.layout, dbias.layout, workspace.size);
#define cb(DType) \
if (data.layout.dtype == DType()) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(backward<typename DTypeTrait<DType>::ctype>( \
diff, data, weight, mean, rstd, ddata, dweight, dbias, param())); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/layer_norm/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class LayerNormForwardImpl final : public LayerNormForward {
public:
using LayerNormForward::LayerNormForward;
void exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
class LayerNormBackwardImpl final : public LayerNormBackward {
public:
using LayerNormBackward::LayerNormBackward;
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -57,6 +57,15 @@ struct DeduceLayoutProxy<Opr, 5, true> {
}
};
template <typename Opr>
struct DeduceLayoutProxy<Opr, 6, true> {
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 6);
opr->deduce_layout(
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]);
}
};
template <typename Opr>
struct DeduceLayoutProxy<Opr, 5, false> {
static void deduce_layout(Opr*, TensorLayoutArray&) {}
......
/**
* \file dnn/test/cuda/layer_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/cuda/fixture.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, LAYERNORM_FORWARD) {
using Param = LayerNormForward::Param;
Param param;
param.affine = true;
param.eps = 1e-6;
param.normalized_dim = 1;
Checker<LayerNormForward> checker(handle_cuda());
checker.set_epsilon(1e-2);
auto run = [&](DType d) {
for (size_t n_slices : {10, 30})
for (size_t slice_len : {10, 30}) {
param.normalized_size = slice_len;
checker.set_param(param)
.set_dtype(0, d)
.set_dtype(1, d)
.set_dtype(2, d)
.set_dtype(3, d)
.set_dtype(4, dtype::Float32())
.set_dtype(5, dtype::Float32())
.execs({{n_slices, slice_len},
{slice_len},
{slice_len},
{n_slices, slice_len},
{n_slices},
{n_slices}});
}
};
run(dtype::Float32());
run(dtype::Float16());
run(dtype::BFloat16());
}
TEST_F(CUDA, LAYERNORM_BACKWARD) {
using Param = LayerNormBackward::Param;
Param param;
param.affine = true;
param.eps = 1e-6;
param.normalized_dim = 1;
Checker<LayerNormBackward> checker(handle_cuda());
checker.set_epsilon(1e-1);
auto run = [&](DType d) {
for (size_t n_slices : {10, 30})
for (size_t slice_len : {10, 30}) {
param.normalized_size = slice_len;
checker.set_param(param)
.set_dtype(0, d)
.set_dtype(1, d)
.set_dtype(2, d)
.set_dtype(3, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_dtype(5, d)
.set_dtype(6, d)
.set_dtype(7, d)
.execs({{n_slices, slice_len},
{n_slices, slice_len},
{slice_len},
{n_slices},
{n_slices},
{n_slices, slice_len},
{slice_len},
{slice_len}});
}
};
run(dtype::Float32());
run(dtype::Float16());
run(dtype::BFloat16());
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -1066,57 +1066,6 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
return cached / down
@lru_cache(maxsize=None)
def _get_layerNorm(device, dtype, dim, gopt_level=2):
@subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level)
def layerNormAffine(inputs, f, c):
inp, eps, _flatten_shape, weight, bias = inputs
inp_shape = f(GetVarShape(), inp)
inp = f(Reshape(axis=dim), inp, _flatten_shape)
mean = f(Reduce(mode="mean", axis=-1), inp)
x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
reduce_shape = f(GetVarShape(), x2s)
reduce_size = f(
"//",
f(Reduce(mode="product", axis=0), inp_shape),
f(Reduce(mode="product", axis=0), reduce_shape),
)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
affine_oup = f(Reshape(), oup, inp_shape)
affine_oup = f("fma3", affine_oup, weight, bias)
# NOTE: return oup make backward faster but take more memory
return (affine_oup, oup, mean, x2s), (True, False, False, False)
@subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level)
def layerNorm(inputs, f, c):
inp, eps, _flatten_shape = inputs
inp_shape = f(GetVarShape(), inp)
inp = f(Reshape(axis=dim), inp, _flatten_shape)
mean = f(Reduce(mode="mean", axis=-1), inp)
x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
reduce_shape = f(GetVarShape(), x2s)
reduce_size = f(
"//",
f(Reduce(mode="product", axis=0), inp_shape),
f(Reduce(mode="product", axis=0), reduce_shape),
)
reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
oup = f(Reshape(), oup, inp_shape)
return (oup,), (True,)
return (layerNorm, layerNormAffine)
def layer_norm(
inp: Tensor,
normalized_shape: tuple,
@@ -1133,32 +1082,34 @@ def layer_norm(
normalized_shape: the shape that you want to be normalized
affine: whether to use weight and bias
weight: must not be None when affine is true
bias: must not be None when the bias is true
bias: must not be None when affine is true
eps: a value added to the denominator for numerical stability. Default: 1e-5
"""
if amp._enabled:
inp, weight, bias = cast_tensors(inp, weight, bias, promote=True)
_device = inp.device
_dtype = inp.dtype
_dim = len(inp.shape) - len(normalized_shape)
if isinstance(normalized_shape, int):
normalized_shape = [normalized_shape]
_flatten_shape = concat(
(
convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device),
convert_single_value(-1, dtype="int32", device=inp.device),
)
)
(layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim)
normalized_dim = len(normalized_shape)
assert normalized_dim > 0
eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
normalized_size = 1
for i in range(normalized_dim):
normalized_size = normalized_size * normalized_shape[i]
op = builtin.LayerNorm(
affine=affine,
eps=eps,
normalized_dim=normalized_dim,
normalized_size=normalized_size,
)
if affine:
outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias)
assert weight is not None and bias is not None
return apply(op, inp, weight, bias)[0]
else:
outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape)
return outvar
# assert weight is None and bias is None
return apply(op, inp)[0]
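With the builtin op in place, a call now lowers to a single fused kernel instead of the subgraph implementation removed above. A minimal usage sketch (shapes arbitrary, mirroring the test below):

```python
import numpy as np
from megengine import Tensor
import megengine.functional as F

x = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")
w = Tensor(np.ones((28, 28)), dtype="float32")
b = Tensor(np.zeros((28, 28)), dtype="float32")

# normalize over the trailing (28, 28) axes; weight/bias are required when affine=True
y = F.nn.layer_norm(x, (28, 28), True, w, b)  # eps defaults to 1e-5
```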
def batch_norm(
......
@@ -865,61 +865,6 @@ def test_conv1d():
)
def test_layer_norm():
def _layer_norm(x, normalized_shape, affine, weight=None, bias=None, eps=1e-5):
__layer_norm = LayerNorm(normalized_shape=normalized_shape, affine=affine)
__layer_norm.weight = weight
__layer_norm.bias = bias
return __layer_norm(x)
def _layer_norm_numpy(
x, normalized_shape, affine, weight=None, bias=None, eps=1e-5
):
x_shape = x.shape
dim_delta = len(x_shape) - len(normalized_shape)
non_flatten_shape = x_shape[:dim_delta]
x = x.reshape(*non_flatten_shape, -1)
mean = x.mean(axis=-1, keepdims=True)
var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean
x = (x - mean) / F.sqrt(var + eps)
x = x.reshape(x_shape)
if affine:
x = weight * x + bias
return x
normalized_shape = (28, 28)
inp_feat = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")
weight = Tensor(np.random.randn(28, 28), dtype="float32")
bias = Tensor(np.random.randn(28, 28), dtype="float32")
inp_feat = inp_feat + 1
weight = weight + 1
bias = bias
affine = False
outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
targetvar = _layer_norm_numpy(inp_feat, normalized_shape, affine, weight, bias)
assert abs(outvar - targetvar).mean() < 1e-7
# no random, affine True
normalized_shape = (28, 28)
inp_feat = Tensor(np.ones((32, 64, 28, 28)), dtype="float32")
weight = Tensor(np.ones((28, 28)), dtype="float32")
bias = Tensor(np.zeros((28, 28)), dtype="float32")
affine = True
outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
targetvar = _layer_norm(inp_feat, normalized_shape, affine, weight, bias)
assert abs((outvar - targetvar).mean()) < 1e-7
assert abs(outvar.mean()) < 1e-7
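The backward kernel is reached through autodiff rather than called directly; a minimal sketch, assuming the contemporaneous `GradManager` API:

```python
import numpy as np
import megengine.autodiff as ad
import megengine.functional as F
from megengine import Tensor

x = Tensor(np.random.randn(10, 30), dtype="float32")
w = Tensor(np.ones((30,)), dtype="float32")
b = Tensor(np.zeros((30,)), dtype="float32")

# sketch: assumes the GradManager attach/backward API
gm = ad.GradManager().attach([x, w, b])
with gm:
    y = F.nn.layer_norm(x, (30,), True, w, b)
    gm.backward(y.sum())  # LayerNormBackward fills x.grad, w.grad, b.grad
```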
def test_batchnorm2d_autocast():
"""check amp's result is equal to manually converted result"""
amp.enabled = True
......
@@ -43,7 +43,7 @@ def test_cross_entropy():
x = softmax(x)
l_ref = ref(x, y)
l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
np.testing.assert_allclose(l.numpy(), l_ref)
np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6)
def test_cross_entropy_reduction():
......
@@ -20,6 +20,7 @@
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
@@ -636,4 +637,29 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
} // namespace lrn
namespace layer_norm {
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const LayerNorm&>(def);
size_t nr_inp = inputs.size();
auto p = op.param();
mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
OperatorNodeConfig config{op.make_name()};
if (nr_inp == 3) {
return opr::LayerNorm::make(
inputs[0], inputs[1], inputs[2], op.param(), config)[0]
.node()
->owner_opr();
} else {
return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
.node()
->owner_opr();
}
}
OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback();
} // namespace layer_norm
} // namespace mgb::imperative
@@ -431,4 +431,6 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>;
def LRN: MgbHashableOp<"LRN", [LRNParam]>;
def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>;
#endif // MGB_OPS
@@ -16,6 +16,7 @@
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
@@ -420,6 +421,47 @@ struct OprMaker<opr::BatchNormBackward, 6> {
}
};
template <>
struct OprMaker<opr::LayerNorm, 0> {
using Param = opr::LayerNorm::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 3) {
return opr::LayerNorm::make(i[0], i[1], i[2], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(i.size() == 1);
return opr::LayerNorm::make(i[0], param, config)[0].node()->owner_opr();
}
}
};
// OprMaker in MGB_SEREG_OPR only support unique output opr
template <>
struct OprMaker<opr::LayerNormBackward, 0> {
using Param = opr::LayerNormBackward::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 5) {
return opr::LayerNormBackward::make(
i[0], i[1], i[2], i[3], i[4], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(i.size() == 4);
return opr::LayerNormBackward::make(
i[0], i[1], i[2], i[3], param, config)[0]
.node()
->owner_opr();
}
}
};
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller2 {
template <typename Opr>
@@ -641,6 +683,8 @@ MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);
MGB_SEREG_OPR(LSQ, 4);
MGB_SEREG_OPR(LSQBackward, 5);
MGB_SEREG_OPR(LayerNorm, 0);
MGB_SEREG_OPR(LayerNormBackward, 0);
} // namespace opr
} // namespace mgb
......
/**
* \file src/opr/impl/dnn/layer_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/utility.h"
#include "../internal/megdnn_opr_wrapper.inl"
using namespace mgb;
using namespace opr;
/* ==================== LayerNormForward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormForward);
LayerNormForward::LayerNormForward(
VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
const OperatorNodeConfig& config)
: Super{data->owner_graph(), config, "layer_norm", {data, weight, bias}} {
init_megdnn_opr(*this, param);
add_input({data, weight, bias});
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
}
LayerNormForward::LayerNormForward(
VarNode* data, const Param& param, const OperatorNodeConfig& config)
: Super{data->owner_graph(), config, "layer_norm", {data}} {
init_megdnn_opr(*this, param);
add_input({data});
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
}
SymbolVarArray LayerNormForward::make(
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param,
const OperatorNodeConfig& config) {
auto outs = data.node()
->owner_graph()
->insert_opr(std::make_unique<LayerNormForward>(
data.node(), weight.node(), bias.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
SymbolVarArray LayerNormForward::make(
SymbolVar data, const Param& param, const OperatorNodeConfig& config) {
auto outs = data.node()
->owner_graph()
->insert_opr(std::make_unique<LayerNormForward>(
data.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
void LayerNormForward::get_output_var_shape(
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
uint64_t normalized_dim = param().normalized_dim;
out_shape[0] = inp_shape[0];
TensorShape unnormalized_shape;
unnormalized_shape.ndim = inp_shape[0].ndim - normalized_dim;
for (size_t i = 0; i < unnormalized_shape.ndim; ++i) {
unnormalized_shape.shape[i] = inp_shape[0].shape[i];
}
out_shape[1] = unnormalized_shape;
out_shape[2] = unnormalized_shape;
}
size_t LayerNormForward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return 0;
}
void LayerNormForward::scn_do_execute() {
if (param().affine) {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(), {});
} else {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), {}, {},
output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(), {});
}
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LayerNormForward) {
auto p = opr.param();
SymbolVarArray grad;
VarNodeArray ret;
if (p.affine) {
mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx);
grad = LayerNormBackward::make(
out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2),
opr.param());
} else {
mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx);
grad = LayerNormBackward::make(
out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param());
}
uint32_t nr_ret = p.affine ? 3 : 1;
for (uint32_t i = 0; i < nr_ret; ++i) {
ret.push_back(grad[i].node());
}
return ret;
}
#endif
/* ==================== LayerNormBackward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormBackward);
LayerNormBackward::LayerNormBackward(
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config)
: Super({diff->owner_graph(),
config,
"layer_norm_backward",
{diff, data, weight, mean, rstd}},
0, true) {
init_megdnn_opr(*this, param);
add_input({diff, data, weight, mean, rstd});
}
LayerNormBackward::LayerNormBackward(
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param,
const OperatorNodeConfig& config)
: Super({diff->owner_graph(),
config,
"layer_norm_backward",
{diff, data, mean, rstd}},
0, true) {
init_megdnn_opr(*this, param);
add_input({diff, data, mean, rstd});
auto mark_empty_var = [&](VarNode* var) {
var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::VOLATILE_CONTENT);
};
mark_empty_var(output(1));
mark_empty_var(output(2));
}
SymbolVarArray LayerNormBackward::make(
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) {
auto outs = diff.node()
->owner_graph()
->insert_opr(std::make_unique<LayerNormBackward>(
diff.node(), data.node(), weight.node(), mean.node(),
rstd.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
SymbolVarArray LayerNormBackward::make(
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
const Param& param, const OperatorNodeConfig& config) {
auto outs = diff.node()
->owner_graph()
->insert_opr(std::make_unique<LayerNormBackward>(
diff.node(), data.node(), mean.node(), rstd.node(),
param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
void LayerNormBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1)));
if (param().affine) {
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2)));
} else {
TensorShape empty;
empty.ndim = 0;
mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty));
}
this->init_output_static_infer_desc_workspace(false);
}
void LayerNormBackward::init_output_dtype() {
output(0)->dtype(input(1)->dtype());
output(1)->dtype(input(2)->dtype());
output(2)->dtype(input(2)->dtype());
}
size_t LayerNormBackward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return 0;
}
void LayerNormBackward::scn_do_execute() {
if (param().affine) {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(), {});
} else {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
{}, input(2)->dev_tensor().as_megdnn(),
input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
{}, {}, {});
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file src/opr/include/megbrain/opr/dnn/layer_norm.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"
namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
LayerNormForward, intl::MegDNNOprWrapperFwd<megdnn::LayerNormForward>) // {
public:
MGE_WIN_DECLSPEC_FUC LayerNormForward(
VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC LayerNormForward(
VarNode* data, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar data, const Param& param = {},
const OperatorNodeConfig& config = {});
private:
void get_output_var_shape(
const TensorShapeArray& inp_shape,
TensorShapeArray& out_shape) const override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
};
using LayerNorm = LayerNormForward;
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
LayerNormBackward, intl::MegDNNOprWrapperBwd<megdnn::LayerNormBackward>) // {
public:
MGE_WIN_DECLSPEC_FUC LayerNormBackward(
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC LayerNormBackward(
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
SymbolVar rstd, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
const Param& param = {}, const OperatorNodeConfig& config = {});
private:
void init_output_static_infer_desc() override;
void init_output_dtype() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
};
} // namespace opr
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file src/opr/test/dnn/layer_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
#include "megdnn/oprs.h"
#include <cmath>
#include <iomanip>
#include <random>
#include <sstream>
using namespace mgb;
namespace {
using Param = opr::LayerNormForward::Param;
void run_forward(bool is_affine, size_t normalized_size) {
using Checker = AutoOprChecker<3, 3>;
Param param;
param.eps = 1e-5;
param.affine = is_affine;
param.normalized_dim = 1;
param.normalized_size = normalized_size;
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
auto out = opr::LayerNormForward::make(inputs[0], inputs[1], inputs[2], param);
return {out[0], out[1], out[2]};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto opr =
MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu()))
->create_operator<megdnn::LayerNormForward>();
auto inp_shape = inp[0]->shape();
auto n_slices = inp_shape[0];
auto slice_len = inp_shape[1];
opr->param() = param;
dest[0].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices, slice_len});
dest[1].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices});
dest[2].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices});
opr->exec(
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(), {});
};
auto gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(0.f);
src = *src_gen(src.shape(), src.comp_node());
};
Checker::RunOptions option;
option.numdiff_max_err = 1e-4;
Checker checker{make_graph, fwd};
checker.set_input_generator(0, gen);
checker.set_input_generator(1, gen);
checker.set_input_generator(2, gen);
checker.set_input_allow_grad(0, false);
checker.set_input_allow_grad(1, false);
checker.set_input_allow_grad(2, false);
checker.set_output_allow_grad(0, false);
checker.set_output_allow_grad(1, false);
checker.set_output_allow_grad(2, false);
checker.run({TensorShape{normalized_size, normalized_size},
TensorShape{normalized_size}, TensorShape{normalized_size}},
option)
.run({TensorShape{normalized_size, normalized_size},
TensorShape{normalized_size}, TensorShape{normalized_size}},
option)
.run({TensorShape{normalized_size, normalized_size},
TensorShape{normalized_size}, TensorShape{normalized_size}},
option);
}
TEST(TestOprDNN, LayerNormForwardAffine) {
REQUIRE_GPU(1);
run_forward(true, 1);
run_forward(true, 16);
run_forward(true, 17);
}
} // anonymous namespace
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -116,6 +116,7 @@ union OperatorParam {
param.Padding = 82,
param.ShuffleRNG = 83,
param.CheckNonFinite = 84,
param.LayerNorm = 85,
}
table Operator {
......