Commit c0ccd0ea authored by Megvii Engine Team

feat(mge/bn): add NHWC support for bn

GitOrigin-RevId: 0a5bb6f72df8862bfa4c6afcc31561fd691a5ecd
Parent 3d3666b6
@@ -1122,7 +1122,8 @@ def batch_norm(
     momentum: float = 0.9,
     eps: float = 1e-5,
     inplace: bool = True,
-    compute_mode="default"
+    compute_mode="default",
+    param_dim="dim_1c11"
 ):
     r"""Applies batch normalization to the input.
@@ -1147,16 +1148,23 @@ def batch_norm(
     if inp.ndim != 4:
         raise NotImplementedError("batch_norm for ndim != 4")

-    C = inp.shape[1]
+    if param_dim == "dim_1c11":
+        C = inp.shape[1]
+        pshape = (1, C, 1, 1)
+    elif param_dim == "dim_111c":
+        C = inp.shape[3]
+        pshape = (1, 1, 1, C)
+    else:
+        raise ValueError("Invalid param_dim {}".format(param_dim))

     def make_full_if_none(x, value):
         if x is None:
             (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
-            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
+            shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
         elif x.ndim == 1:
-            shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
+            shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Reshape(), x, shape)
             return result
         return x
@@ -1183,19 +1191,19 @@ def batch_norm(
     if not training:
         op = builtin.BatchNorm(
-            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
+            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim=param_dim
         )
         ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
         return ret
     else:
         op = builtin.BatchNorm(
-            avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
+            avg_factor=1 - momentum, epsilon=eps, param_dim=param_dim
         )
         if has_mean or has_var:
             running_mean = make_full_if_none(running_mean, 0)
             running_var = make_full_if_none(running_var, 1)
-            new_mean, new_var, _, _, inp = apply(
+            new_mean, new_var, *_, inp = apply(
                 op, inp, weight, bias, running_mean, running_var
             )
             if not has_mean:
@@ -1213,7 +1221,7 @@ def batch_norm(
             else:
                 return inp, new_mean, new_var
         else:
-            (_, _, inp,) = apply(op, inp, weight, bias)
+            inp = apply(op, inp, weight, bias)[-1]
             return inp
...
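Note: the block below is a minimal usage sketch of the extended functional API, not part of the commit. The tensor shapes are illustrative (they mirror the NHWC test added further down); `weight`, `bias`, `training`, `inplace` and `param_dim` keywords appear in this diff, while the `running_mean`/`running_var` keyword names are assumed from the identifiers used in the function body above.

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

# NHWC input: the channel axis is last, so affine parameters and running
# statistics use the (1, 1, 1, C) layout selected by param_dim="dim_111c".
inp = tensor(np.random.randn(2, 16, 16, 8), dtype=np.float32)
weight = tensor(np.ones((1, 1, 1, 8), dtype=np.float32))
bias = tensor(np.zeros((1, 1, 1, 8), dtype=np.float32))

# Without running statistics, only the normalized output is returned.
out = F.batch_norm(
    inp, weight=weight, bias=bias, training=True, inplace=False,
    param_dim="dim_111c",  # the default "dim_1c11" keeps the NCHW behaviour
)

# With running statistics and inplace=False, the updated statistics are
# returned alongside the output (see the training branch in the diff above).
running_mean = tensor(np.zeros((1, 1, 1, 8), dtype=np.float32))
running_var = tensor(np.ones((1, 1, 1, 8), dtype=np.float32))
out, new_mean, new_var = F.batch_norm(
    inp, running_mean=running_mean, running_var=running_var,
    weight=weight, bias=bias, training=True, inplace=False,
    param_dim="dim_111c",
)
```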
@@ -27,6 +27,7 @@ class _BatchNorm(Module):
         track_running_stats=True,
         freeze=False,
         compute_mode="default",
+        param_dim="dim_1c11",
         **kwargs
     ):
         super(_BatchNorm, self).__init__(**kwargs)
@@ -38,6 +39,7 @@ class _BatchNorm(Module):
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
         self.compute_mode = compute_mode
+        self.param_dim = param_dim
         if self.freeze:
             assert (
                 self._track_running_stats_saved
@@ -125,6 +127,7 @@ class _BatchNorm(Module):
                 momentum=exponential_average_factor,
                 eps=self.eps,
                 compute_mode=self.compute_mode,
+                param_dim=self.param_dim,
             )
         if _ndims != 4:
...
@@ -811,7 +811,8 @@ def test_batch_conv_bias():
     run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)


-def test_conv2d_io16c32():
+def test_conv2d_autocast():
+    """check amp's result is equal to manually converted result"""
     amp.enabled = True
     inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
     weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
@@ -918,11 +919,14 @@ def test_layer_norm():
     assert abs(outvar.mean()) < 1e-7


-def test_batchnorm2d_io16c32():
+def test_batchnorm2d_autocast():
+    """check amp's result is equal to manually converted result"""
     amp.enabled = True
-    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
-    weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
-    bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
+    tshape = (1, 224, 224, 3)
+    pshape = (1, 1, 1, 3)
+    inp = tensor(np.random.randn(*tshape), dtype=np.float32)
+    weight = tensor(np.ones(pshape, dtype=np.float32))
+    bias = tensor(np.zeros(pshape, dtype=np.float32))

     out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
...
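For comparison, a minimal sketch of the pre-existing NCHW-layout call (essentially the lines removed from the test above); it is unaffected by this change because `param_dim` defaults to `"dim_1c11"`.

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

# Default param_dim="dim_1c11": NCHW input with (1, C, 1, 1) affine parameters.
inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
weight = tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
bias = tensor(np.zeros((1, 3, 1, 1), dtype=np.float32))
out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
```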
@@ -51,16 +51,16 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
     // need running mean/variance
     bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING;
-    size_t nr_out = need_stat? 5 : 3;
+    size_t nr_out = need_stat? 6 : 4;
     SmallVector<LogicalTensorDesc> out_shapes(nr_out);
     auto&& i0 = inputs[0];
     auto&& i1 = inputs[1];
     // [running_mean, running_var,] save_mean, save_var
-    for (size_t i = 0; i < nr_out-1; ++ i) {
+    for (size_t i = 0; i < nr_out-2; ++ i) {
         out_shapes[i] = {i1.layout, i1.comp_node};
     }
-    // output tensor
-    out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
+    out_shapes[nr_out-2] = {TensorLayout({0}, dtype::Byte()), i0.comp_node}; // reserve
+    out_shapes[nr_out-1] = {i0.layout, i0.comp_node}; // output
     return {out_shapes, out_shapes[nr_out-1].layout.ndim != 0};
 }
...
@@ -689,7 +689,8 @@ ProxyGraph::make_backward_graph(
         output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
     }
     auto output_grads = make_input_place_holders(output_descs);
-    mgb_assert(output_grads.size() == output_has_grad.size());
+    mgb_assert(output_grads.size() == output_has_grad.size(), "%d vs %d",
+            output_grads.size(), output_has_grad.size());
     bool any_input_has_grad = false;
     for (size_t i = 0; i < output_grads.size(); ++ i) {
         if (!output_has_grad[i]) {
...
@@ -207,7 +207,7 @@ TEST(TestImperative, BatchNormGrad) {
         attr.param.write_pod(param);
         OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat},
                 {true, true, true, false, false},
-                {false, false, false, false, true});
+                {false, false, false, false, false, true});
     }
     {
         auto op = OprAttr::make("BatchNorm");
@@ -216,7 +216,7 @@ TEST(TestImperative, BatchNormGrad) {
         param.fwd_mode = Param::FwdMode::TRAINING;
         attr.param.write_pod(param);
         OpDef::make_backward_graph(attr, {inp, stat, stat}, {true, true, true},
-                {false, false, true});
+                {false, false, false, true});
     }
 }
...
@@ -99,7 +99,7 @@ UNUSED void print(const char* s) {
 OprChecker::OprChecker(std::shared_ptr<OpDef> opdef)
         : m_op(opdef) {}

-void OprChecker::run(std::vector<InputSpec> inp_keys) {
+void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
     HostTensorGenerator<> gen;
     size_t nr_inps = inp_keys.size();
     SmallVector<HostTensorND> host_inp(nr_inps);
@@ -151,6 +151,8 @@ void OprChecker::run(std::vector<InputSpec> inp_keys) {
     func->execute().wait(); // run last because it may contain inplace operations
     for(size_t i = 0; i < nr_oups; ++ i) {
+        if (bypass.find(i) != bypass.end())
+            continue;
         MGB_ASSERT_TENSOR_EQ(host_sym_oup[i], host_imp_oup[i]);
     }
 }
...
@@ -23,7 +23,7 @@ class OprChecker {
 public:
     using InputSpec = std::variant<HostTensorND, TensorShape>;
     OprChecker(std::shared_ptr<OpDef> opdef);
-    void run(std::vector<InputSpec> inp_shapes);
+    void run(std::vector<InputSpec> inp_shapes, std::set<size_t> bypass={});
 private:
     std::shared_ptr<OpDef> m_op;
 };
...
@@ -73,7 +73,7 @@ TEST(TestImperative, BatchNorm) {
             TensorShape{1, C, 1, 1},
             TensorShape{1, C, 1, 1},
             TensorShape{1, C, 1, 1}
-    });
+    }, {4});
 }

 TEST(TestImperative, Concat) {
...
@@ -1766,7 +1766,7 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
                         x.dtype().name(), res.dtype().name());
             }
             rewriter.replace_var(
-                    opr->output(4), res.node(),
+                    opr->output(5), res.node(),
                     mgb_cstr_log(
                             "replace batch_norm(x, scale, bias, mean, "
                             "varience) "
...
@@ -35,6 +35,7 @@
 #include "megdnn/tensor_format.h"

 #include <random>
+#include <vector>

 #if MGB_CUDA
 #include <cudnn.h>
@@ -1665,6 +1666,9 @@ TEST(TestGoptInference, concatbypass) {
 TEST(TestGoptInference, ConvertBatchNormPass) {
     auto cn = CompNode::load("cpu0");
+    std::vector<TensorShape> shps = {{1, 3, 1, 1}, {1, 1, 1, 3}},
+                             xshps = {{2, 3, 16, 24}, {2, 16, 24, 3}};
+    for (int t = 0; t < 2; t++) {
     HostTensorGenerator<> gen(0, 1, 0);
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
@@ -1676,9 +1680,10 @@ TEST(TestGoptInference, ConvertBatchNormPass) {
                 .rename(name);
     };
     using Param = opr::BatchNorm::Param;
-    Param param(Param::ParamDim::DIM_1C11, Param::FwdMode::INFERENCE);
-    TensorShape shp = {1, 3, 1, 1};
-    auto x = mkvar("x", {2, 3, 16, 24}), scale = mkcvar("scale", shp),
+    Param::ParamDim param_dim = t == 0 ? Param::ParamDim::DIM_1C11 : Param::ParamDim::DIM_111C;
+    Param param(param_dim, Param::FwdMode::INFERENCE);
+    TensorShape shp = shps[t], xshp = xshps[t];
+    auto x = mkvar("x", xshp), scale = mkcvar("scale", shp),
          bias = mkcvar("bias", shp), mean = mkcvar("mean", shp);
     auto host_variance = gen(shp, cn);
     for (size_t i = 0; i < shp.total_nr_elems(); ++i) {
@@ -1687,7 +1692,7 @@ TEST(TestGoptInference, ConvertBatchNormPass) {
     }
     auto variance = opr::SharedDeviceTensor::make(*graph, *host_variance)
                             .rename("variance");
-    auto y = opr::BatchNorm::make(x, scale, bias, mean, variance, param)[4];
+    auto y = opr::BatchNorm::make(x, scale, bias, mean, variance, param)[5];
     SymbolVar y_opt;
     unpack_vector(gopt::optimize_for_inference(
                           {y}, gopt::OptimizeForInferenceOptions{}),
@@ -1703,6 +1708,7 @@ TEST(TestGoptInference, ConvertBatchNormPass) {
                            make_callback_copy(y_opt, host_y_opt)});
     func->execute();
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+    }
 }

 TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {
...
@@ -62,10 +62,12 @@ BatchNormForward::BatchNormForward(VarNode *x,
 {
     init_megdnn_opr(*this, param);
-    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     add_input({x, scale, bias, mean, variance});
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // reserve
+    output(5)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+
+    // running mean/var
     if (param.fwd_mode == Param::FwdMode::INFERENCE) {
         auto mark_empty_var = [&](VarNode *var) {
             var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
@@ -92,9 +94,10 @@ BatchNormForward::BatchNormForward(VarNode *x,
         {x, scale, bias}}
 {
     init_megdnn_opr(*this, param);
-    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     add_input({x, scale, bias});
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // reserve
+    output(5)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     auto mark_empty_var = [&](VarNode *var) {
         var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                 .add_flag(VarNode::Flag::VOLATILE_CONTENT);
@@ -151,7 +154,7 @@ BatchNormForward::do_make_node_prop() const {
 void BatchNormForward::scn_do_execute() {
     auto &&x = input(0)->dev_tensor();
-    auto &&y = output(4)->dev_tensor();
+    auto &&y = output(5)->dev_tensor();
     if (need_stats()) {
         auto &&o0 = output(0)->dev_tensor(),
              &&o1 = output(1)->dev_tensor(),
@@ -192,9 +195,10 @@ void BatchNormForward::scn_do_execute() {
     }
     auto save_mean = output(2)->dev_tensor().as_megdnn();
     auto save_variance = output(3)->dev_tensor().as_megdnn();
+    auto reserve = output(4)->dev_tensor().as_megdnn();
     auto workspace = intl::get_megdnn_workspace_from_var(output().back());
     megdnn_opr()->exec(x.as_megdnn(), scale, bias, mean, variance,
-        save_mean, save_variance, y.as_megdnn(), workspace);
+        save_mean, save_variance, reserve, y.as_megdnn(), workspace);
 }

 void BatchNormForward::add_input_layout_constraint() {
@@ -209,17 +213,24 @@ void BatchNormForward::get_output_var_shape(
             "got input dim: %zu, scale dim: %zu, bias dim: %zu",
             inp_shape[0].ndim, inp_shape[1].ndim, inp_shape[2].ndim);
-    size_t inp_c = inp_shape[0][1],
-           scale_c = inp_shape[1][1],
-           bias_c = inp_shape[2][1];
+    size_t channel_idx;
+    if (param().param_dim == Param::ParamDim::DIM_111C) {
+        channel_idx = 3;
+    } else {
+        channel_idx = 1;
+    }
+    size_t inp_c = inp_shape[0][channel_idx],
+           scale_c = inp_shape[1][channel_idx],
+           bias_c = inp_shape[2][channel_idx];
     mgb_assert(inp_c == scale_c && inp_c == bias_c,
         "inconsistent channel size, input chennel: %zu, scale channel: %zu, bias channel: %zu",
         inp_c, scale_c, bias_c);
-    out_shape[4] = inp_shape[0];
+    out_shape[5] = inp_shape[0];
     for (size_t i = 0; i < 4; ++ i) {
         out_shape[i] = inp_shape[1];
     }
+    out_shape[4] = {megdnn_opr()->get_reserve_in_bytes({inp_shape[0], input(0)->dtype()})};
     if (!need_stats()) {
         out_shape[0] = out_shape[1] = {0};
     }
@@ -231,7 +242,7 @@ size_t BatchNormForward::get_workspace_size_bytes(
 #define in(x) {input_shapes[x], input(x)->dtype()}
 #define out(x) {output_shapes[x], output(x)->dtype()}
     return megdnn_opr()->get_workspace_in_bytes(
-            in(0), in(1), in(2), out(0), out(1), out(2), out(3), out(4));
+            in(0), in(1), in(2), out(0), out(1), out(2), out(3), out(4), out(5));
 #undef in
 #undef out
 }
@@ -249,7 +260,8 @@ void BatchNormForward::init_output_dtype() {
     for (size_t i = 2; i < nr_inp; ++ i) {
         mgb_assert(input(1)->dtype() == input(i)->dtype());
     }
-    output(4)->dtype(input(0)->dtype());
+    output(4)->dtype(dtype::Byte()); // reserve
+    output(5)->dtype(input(0)->dtype()); // output
     for (size_t i = 0; i < 4; ++ i) {
         output(i)->dtype(input(1)->dtype());
     }
@@ -271,9 +283,10 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) {
     switch (opr.param().fwd_mode) {
         case BatchNorm::Param::FwdMode::TRAINING:
             grad = BatchNormBackward::make(
-                    opr.input(0), out_grad[4],
+                    opr.input(0), out_grad[5],
                     opr.output(2), opr.output(3),
-                    opr.input(1), opr.param());
+                    opr.input(1), opr.output(4), // reserve
+                    opr.param());
             for (size_t i = 0; i < 3; ++ i) {
                 ret[i] = grad[(i + 2) % 3].node();
             }
@@ -281,13 +294,13 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) {
         case BatchNorm::Param::FwdMode::INFERENCE:
             auto sqrt_var = PowC::make((SymbolVar{opr.input(4)}
                     + static_cast<dt_float32>(opr.param().epsilon)), 0.5, opr.config());
-            auto d_bn_scale_unreduced = SymbolVar{out_grad[4]} *
+            auto d_bn_scale_unreduced = SymbolVar{out_grad[5]} *
                     (SymbolVar{opr.input(0)} - SymbolVar{opr.input(3)}) / sqrt_var;
             auto d_bn_scale = Reduce::make(d_bn_scale_unreduced,
                     Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(1)));
-            auto d_bn_bias = Reduce::make(out_grad[4],
+            auto d_bn_bias = Reduce::make(out_grad[5],
                     Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(2)));
-            auto dx = SymbolVar{out_grad[4]} * SymbolVar{opr.input(1)} / sqrt_var;
+            auto dx = SymbolVar{out_grad[5]} * SymbolVar{opr.input(1)} / sqrt_var;
             ret[0] = dx.node();
             ret[1] = d_bn_scale.node();
@@ -302,26 +315,26 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNormBackward);
 BatchNormBackward::BatchNormBackward(VarNode *x,
         VarNode *y_grad, VarNode *save_mean,
-        VarNode* save_variance, VarNode *scale,
+        VarNode* save_variance, VarNode *scale, VarNode *reserve,
         const Param &param, const OperatorNodeConfig &config):
     Super({x->owner_graph(), config, "batch_norm_bwd",
-            {x, y_grad, save_mean, save_variance, scale}},
+            {x, y_grad, save_mean, save_variance, scale, reserve}},
             0, true)
 {
     init_megdnn_opr(*this, param);
-    add_input({x, y_grad, save_mean, save_variance, scale});
+    add_input({x, y_grad, save_mean, save_variance, scale, reserve});
 }

 SymbolVarArray BatchNormBackward::make(SymbolVar x,
         SymbolVar y_grad, SymbolVar save_mean,
-        SymbolVar save_variance, SymbolVar scale,
+        SymbolVar save_variance, SymbolVar scale, SymbolVar reserve,
         const Param &param,
         const OperatorNodeConfig &config) {
     auto&& out = x.node()
             ->owner_graph()
             ->insert_opr(std::make_unique<BatchNormBackward>(
                     x.node(), y_grad.node(), save_mean.node(),
-                    save_variance.node(), scale.node(), param, config))
+                    save_variance.node(), scale.node(), reserve.node(), param, config))
             ->output();
     SymbolVarArray ret(out.size());
     for (size_t i = 0; i < ret.size(); i++) {
@@ -355,4 +368,11 @@ void BatchNormBackward::init_output_dtype() {
     output(2)->dtype(input(0)->dtype());
 }

+cg::OperatorNodeBase::NodeProp*
+BatchNormBackward::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(5),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -391,14 +391,14 @@ struct OprMaker<opr::BatchNorm, 0> {
 };

 template <>
-struct OprMaker<opr::BatchNormBackward, 5> {
+struct OprMaker<opr::BatchNormBackward, 6> {
     using Param = opr::BatchNormBackward::Param;
     static cg::OperatorNodeBase* make(const Param& param,
                                       const cg::VarNodeArray& i,
                                       ComputingGraph& graph,
                                       const OperatorNodeConfig& config) {
         MGB_MARK_USED_VAR(graph);
-        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], param,
+        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param,
                                             config)[0]
                 .node()
                 ->owner_opr();
@@ -576,7 +576,7 @@ using ConvBiasForwardV4 = ConvBiasForward;
 MGB_SEREG_OPR(ConvBiasForwardV4, 0);

 MGB_SEREG_OPR(BatchNorm, 0);
-MGB_SEREG_OPR(BatchNormBackward, 5);
+MGB_SEREG_OPR(BatchNormBackward, 6);

 using LocalShareForwardV1 = LocalShareForward;
 using LocalShareBackwardDataV1 = LocalShareBackwardData;
...
@@ -183,6 +183,10 @@ namespace {
 #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
 #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

+#define _NR_INPUTS 6
+#define _NR_OUTPUTS 3
+#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2)
+#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
 } // anonymous namespace

 /* ======================= MegDNNOprWrapperFwd ======================= */
...
@@ -24,7 +24,7 @@ namespace opr {
 /* input:
  *   x, scale, bias, [running_mean, running_variance]
  * output:
- *   running_mean, running_variance, save_mean, save_inv_variance, y
+ *   running_mean, running_variance, save_mean, save_inv_variance, reserve, y
  *
  * All params have the same definition with cudnn batch normalization.
  *
@@ -35,6 +35,9 @@
  *
  * For statistic(mean and variance) update:
  *   running_mean = (1 - moving_average) * running_mean + moving_average * new_mean
+ *
+ * Output reserve is used for cudnnBatchNormalizationForwardTrainingEx, and should
+ * be preserved for backward.
  */
 MGB_DEFINE_OPR_CLASS(BatchNormForward,
         cg::OutshapePureByInshapeOpr<
@@ -86,7 +89,7 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward,
 using BatchNorm = BatchNormForward;

 /* input:
- *   x, y_grad, save_mean, save_inv_variance, scale
+ *   x, y_grad, save_mean, save_inv_variance, scale, reserve
  * output:
  *   scale_grad, bias_grad, x_grad
  */
@@ -97,15 +100,17 @@ MGB_DEFINE_OPR_CLASS(BatchNormBackward,
     public:
         BatchNormBackward(VarNode *x, VarNode *y_grad,
                 VarNode *save_mean, VarNode *save_variance,
-                VarNode *scale,
+                VarNode *scale, VarNode *reserve,
                 const Param &param,
                 const OperatorNodeConfig &config);
         static SymbolVarArray make(SymbolVar x,
                 SymbolVar y_grad, SymbolVar save_mean,
                 SymbolVar save_variance, SymbolVar scale,
+                SymbolVar reserve,
                 const Param &param = {},
                 const OperatorNodeConfig &config = {});
     private:
+        NodeProp* do_make_node_prop() const override;
         void init_output_static_infer_desc() override;
         void init_output_dtype() override;
 };
...
@@ -95,13 +95,13 @@ SymbolVarArray batch_norm(const SymbolVarArray& inputs, const Param &param) {
     SymbolVarArray ret;
     if (inputs.size() == 3) {
         ret = opr::BatchNorm::make(inputs[0], inputs[1], inputs[2], param);
-        return {ret[4], ret[2], ret[3]};
+        return {ret[5], ret[2], ret[3]};
     }
     else {
         mgb_assert(inputs.size() == 5);
         ret = opr::BatchNorm::make(inputs[0], inputs[1], inputs[2],
                                    inputs[3], inputs[4], param);
-        return {ret[4], ret[0], ret[1]};
+        return {ret[5], ret[0], ret[1]};
     }
 }
...