Commit da91e650 authored by Megvii Engine Team

refactor(ops/layer_norm): speed up host-side execution of layer_norm

GitOrigin-RevId: 6f359b5b295f3d340947e0f6ea948c0fc1c19886
Parent 67cfce9f
@@ -1939,6 +1939,11 @@ class LayerNormBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase);
DEF_OPR_PARAM(LayerNorm);
public:
MGE_WIN_DECLSPEC_FUC static void deduce_layout_fwd_impl(
const TensorLayout& data, const Param& p, TensorLayout& dst,
TensorLayout& mean, TensorLayout& rstd);
protected:
void deduce_layout_fwd(
const TensorLayout& data, const TensorLayout& weight,
......
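The header change above makes `deduce_layout_fwd_impl` a public static method, so output layouts can be deduced without instantiating an operator; this is the source of the host-side speedup. A minimal sketch of a direct call, with hypothetical shapes (`normalized_dim` is the `Param` field used in the implementation below):

```cpp
// Hypothetical shapes: layer_norm normalizes the trailing
// p.normalized_dim axes, so mean/rstd keep only the leading axes.
TensorLayout data({4, 3, 16}, dtype::Float32());
megdnn::LayerNorm::Param p;
p.normalized_dim = 1;  // normalize over the last axis only
TensorLayout dst, mean, rstd;
megdnn::LayerNorm::deduce_layout_fwd_impl(data, p, dst, mean, rstd);
// dst  : {4, 3, 16}  (same layout as data)
// mean : {4, 3}
// rstd : {4, 3}
```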
@@ -4,12 +4,11 @@
namespace megdnn {
void LayerNormBase::deduce_layout_fwd(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
MEGDNN_MARK_USED_VAR(weight);
MEGDNN_MARK_USED_VAR(bias);
auto p = param();
using Param = LayerNormBase::Param;
void LayerNormBase::deduce_layout_fwd_impl(
const TensorLayout& data, const Param& p, TensorLayout& dst, TensorLayout& mean,
TensorLayout& rstd) {
TensorShape unnormalized_shape;
unnormalized_shape.ndim = data.ndim - p.normalized_dim;
for (size_t i = 0; i < unnormalized_shape.ndim; ++i) {
@@ -22,6 +21,14 @@ void LayerNormBase::deduce_layout_fwd(
rstd = unnormalized_layout;
}
void LayerNormBase::deduce_layout_fwd(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
MEGDNN_MARK_USED_VAR(weight);
MEGDNN_MARK_USED_VAR(bias);
deduce_layout_fwd_impl(data, param(), dst, mean, rstd);
}
void LayerNormBase::check_layout_fwd(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) {
......
@@ -63,6 +63,7 @@ __all__ = [
"hsigmoid",
"hswish",
"indexing_one_hot",
"layer_norm",
"leaky_relu",
"linear",
"local_conv2d",
@@ -1135,9 +1136,6 @@ def layer_norm(
bias: must not be None when affine is true
eps: a value added to the denominator for numerical stability. Default: 1e-5
"""
if amp._enabled:
inp, weight, bias = cast_tensors(inp, weight, bias, promote=True)
if isinstance(normalized_shape, int):
normalized_shape = [normalized_shape]
......
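The amp cast removed from the Python wrapper here is superseded by the C++ dtype-promotion rule added later in this commit (see `layer_norm_rule` below): under amp autocast, every input whose dtype differs from `DTypePromoteCfg::amp_high_prec_dtype` is converted with `TypeCvt` before the op is applied, so the Python-side `cast_tensors(..., promote=True)` call becomes redundant.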
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb::imperative {
namespace layer_norm {
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const LayerNorm&>(def);
size_t nr_inp = inputs.size();
auto p = op.param();
mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
OperatorNodeConfig config{op.make_name()};
if (nr_inp == 3) {
return opr::LayerNorm::make(
inputs[0], inputs[1], inputs[2], op.param(), config)[0]
.node()
->owner_opr();
} else {
return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
.node()
->owner_opr();
}
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<LayerNorm>();
size_t nr_inp = inputs.size();
auto p = op_def.param();
mgb_assert(
(nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
"num of inputs of pooling should be 1 or 3 but you give %zu",
inputs.size());
auto&& inp = inputs[0];
auto& inp_cn = inp.comp_node;
if (inp.layout.ndim == 0) {
return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
{TensorLayout{dtype::Float32()}, inp_cn, {}},
{TensorLayout{dtype::Float32()}, inp_cn, {}}},
false};
}
TensorLayout oup_layout, mean_layout, rstd_layout;
megdnn::LayerNorm::deduce_layout_fwd_impl(
inp.layout, p, oup_layout, mean_layout, rstd_layout);
return {{{oup_layout, inp_cn, {}},
{mean_layout, inp_cn, {}},
{rstd_layout, inp_cn, {}}},
true};
}
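The boolean in the returned tuple reports whether the layouts are fully deduced: with an unknown input layout (`ndim == 0`) the function returns placeholder descriptors (Float32 for mean/rstd) together with `false`; with a known layout, all three outputs are resolved on the host through the static `deduce_layout_fwd_impl`, with no graph operator or kernel involved, and the function returns `true`.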
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<LayerNorm>();
size_t nr_inp = inputs.size();
auto p = op_def.param();
mgb_assert(
(nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
"num of inputs of pooling should be 1 or 3 but you give %zu",
inputs.size());
auto cn = inputs[0]->comp_node();
DnnOprCaller<megdnn::LayerNorm> caller(cn);
auto&& dnn_opr = caller.op;
dnn_opr->param() = p;
TensorLayout oup_layout, mean_layout, rstd_layout;
megdnn::LayerNorm::deduce_layout_fwd_impl(
inputs[0]->dnn_tensor().layout, p, oup_layout, mean_layout, rstd_layout);
DeviceTensorND out_devtensor =
BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);
DeviceTensorND mean_devtensor =
BlobManager::inst()->alloc_workspace_with_defrag(cn, mean_layout);
DeviceTensorND rstd_devtensor =
BlobManager::inst()->alloc_workspace_with_defrag(cn, rstd_layout);
megdnn::Workspace dnn_wk;
auto wk_size = caller.op->get_workspace_in_bytes(
inputs[0]->dnn_tensor().layout,
p.affine ? inputs[1]->dnn_tensor().layout : TensorLayout(),
p.affine ? inputs[2]->dnn_tensor().layout : TensorLayout(), oup_layout,
mean_layout, rstd_layout);
if (wk_size != 0) {
TensorLayout w_layout({wk_size}, dtype::Byte());
dnn_wk = caller.create_workspace(w_layout);
}
dnn_opr->exec(
inputs[0]->dnn_tensor(),
p.affine ? inputs[1]->dnn_tensor() : megdnn::TensorND(),
p.affine ? inputs[2]->dnn_tensor() : megdnn::TensorND(),
out_devtensor.as_megdnn(), mean_devtensor.as_megdnn(),
rstd_devtensor.as_megdnn(), dnn_wk);
return {Tensor::make(out_devtensor), Tensor::make(mean_devtensor),
Tensor::make(rstd_devtensor)};
}
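Note the non-affine path: when `p.affine` is false the kernel receives empty `TensorLayout`/`TensorND` placeholders for weight and bias, matching the single-input assertion at the top of the function, and a workspace is only allocated when `get_workspace_in_bytes` reports a nonzero size.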
OP_TRAIT_REG(LayerNorm, LayerNorm)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace layer_norm
} // namespace mgb::imperative
\ No newline at end of file
@@ -8,7 +8,6 @@
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
@@ -729,28 +728,4 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
} // namespace lrn
namespace layer_norm {
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const LayerNorm&>(def);
size_t nr_inp = inputs.size();
auto p = op.param();
mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
OperatorNodeConfig config{op.make_name()};
if (nr_inp == 3) {
return opr::LayerNorm::make(
inputs[0], inputs[1], inputs[2], op.param(), config)[0]
.node()
->owner_opr();
} else {
return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
.node()
->owner_opr();
}
}
OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback();
} // namespace layer_norm
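The block deleted above is not lost: the same `apply_on_var_node` now lives in the dedicated layer_norm source file added earlier in this commit, where the trait additionally registers `infer_output_attrs_fallible` and `apply_on_physical_tensor`.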
} // namespace mgb::imperative
@@ -289,6 +289,28 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, inputs);
}
ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
// bypass the normal amp autocast: force layer_norm inputs to the
// high-precision dtype instead of the low-precision autocast dtype
if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
SmallVector<DType> dtypes = get_value_dtypes(inputs);
ValueRefList converted(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
mgb::DType target_dtype = DTypePromoteCfg::amp_high_prec_dtype;
if (dtypes[i] != target_dtype) {
converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
} else {
converted[i] = inputs[i];
}
}
return imperative::apply(op, converted);
}
return imperative::apply(op, inputs);
}
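Unlike the generic promotion rules, this one pins layer_norm to the amp high-precision dtype rather than the autocast low-precision dtype, since normalization statistics are sensitive to reduced precision. A sketch of the effect, with assumed dtypes (not taken from this diff):

```cpp
// Assumed setup under amp autocast:
//   data/weight/bias arrive as float16,
//   DTypePromoteCfg::amp_high_prec_dtype == dtype::Float32().
// layer_norm_rule then TypeCvt's each float16 input to float32 before
// imperative::apply(op, converted) runs, the same cast the removed
// Python-side cast_tensors(..., promote=True) used to perform.
```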
ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
SmallVector<DType> dtypes = get_value_dtypes(inputs);
mgb::DType target_dtype = get_promoted_dtype(dtypes);
@@ -319,6 +341,7 @@ struct DTypePromoteRuleRegistry {
register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule);
register_dtype_promote_rule<LayerNorm>(layer_norm_rule);
}
} register_helper;
......