未验证 提交 77cf305f 编写于 作者: H hong 提交者: GitHub

Add batch norm yaml (#41386)

* update

* fix bug
上级 19cb0d18
...@@ -312,8 +312,8 @@ class InplaceABNGradKernel : public framework::OpKernel<T> { ...@@ -312,8 +312,8 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
phi::BatchNormGradRawKernel<T>( phi::BatchNormGradRawKernel<T>(
static_cast<const typename framework::ConvertToPhiContext< static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx), DeviceContext>::TYPE&>(dev_ctx),
*d_y, *y, *scale, *bias, *saved_mean, *saved_variance, space_opt, *y, *scale, *bias, mean_opt, variance_opt, *saved_mean, *saved_variance,
mean_opt, variance_opt, momentum, epsilon, data_layout, is_test, space_opt, *d_y, momentum, epsilon, data_layout, is_test,
use_global_stats, trainable_statistics, fuse_with_relu, true, d_x, use_global_stats, trainable_statistics, fuse_with_relu, true, d_x,
scale_grad, bias_grad); scale_grad, bias_grad);
} }
......
...@@ -140,10 +140,10 @@ class InplaceABNGradKernel ...@@ -140,10 +140,10 @@ class InplaceABNGradKernel
phi::BatchNormGradRawKernel<T>( phi::BatchNormGradRawKernel<T>(
static_cast<const typename framework::ConvertToPhiContext< static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx), DeviceContext>::TYPE&>(dev_ctx),
*d_y, *y, *scale, *bias, *saved_mean, *saved_variance, space_opt, *y, *scale, *bias, mean_opt, variance_opt, *saved_mean,
mean_opt, variance_opt, momentum, epsilon, data_layout, is_test, *saved_variance, space_opt, *d_y, momentum, epsilon, data_layout,
use_global_stats, trainable_statistics, fuse_with_relu, true, d_x, is_test, use_global_stats, trainable_statistics, fuse_with_relu, true,
scale_grad, bias_grad); d_x, scale_grad, bias_grad);
} }
} }
}; };
......
...@@ -167,6 +167,135 @@ std::vector<Tensor> split_impl(const Tensor& x, ...@@ -167,6 +167,135 @@ std::vector<Tensor> split_impl(const Tensor& x,
return out; return out;
} }
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const Tensor& mean,
const Tensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
kernel_data_type = ParseDataType(x);
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"batch_norm", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "batch_norm API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "batch_norm API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_scale = PrepareData(scale, kernel.InputAt(1), {});
auto input_bias = PrepareData(bias, kernel.InputAt(2), {});
auto input_mean = PrepareData(mean, kernel.InputAt(3), {});
auto input_variance = PrepareData(variance, kernel.InputAt(4), {});
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> api_output;
auto kernel_out_0 = SetKernelOutput(kernel_backend, &std::get<0>(api_output));
std::get<1>(api_output).set_impl(mean.impl());
std::get<2>(api_output).set_impl(variance.impl());
auto kernel_out_1 = SetKernelOutput(kernel_backend, &std::get<1>(api_output));
auto kernel_out_2 = SetKernelOutput(kernel_backend, &std::get<2>(api_output));
auto kernel_out_3 = SetKernelOutput(kernel_backend, &std::get<3>(api_output));
auto kernel_out_4 = SetKernelOutput(kernel_backend, &std::get<4>(api_output));
auto kernel_out_5 = SetKernelOutput(kernel_backend, &std::get<5>(api_output));
phi::MetaTensor meta_out_0(kernel_out_0);
phi::MetaTensor meta_out_1(kernel_out_1);
phi::MetaTensor meta_out_2(kernel_out_2);
phi::MetaTensor meta_out_3(kernel_out_3);
phi::MetaTensor meta_out_4(kernel_out_4);
phi::MetaTensor meta_out_5(kernel_out_5);
phi::BatchNormInferMeta(MakeMetaTensor(*input_x),
MakeMetaTensor(*input_scale),
MakeMetaTensor(*input_bias),
MakeMetaTensor(*input_mean),
MakeMetaTensor(*input_variance),
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
&meta_out_0,
&meta_out_1,
&meta_out_2,
&meta_out_3,
&meta_out_4,
&meta_out_5);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
float,
float,
const std::string&,
bool,
bool,
bool,
bool,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{
(*kernel_fn)(*dev_ctx,
*input_x,
*input_scale,
*input_bias,
*input_mean,
*input_variance,
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
kernel_out_0,
kernel_out_1,
kernel_out_2,
kernel_out_3,
kernel_out_4,
kernel_out_5);
}
return api_output;
}
std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x, std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad, const Tensor& out_grad,
const Scalar& axis) { const Scalar& axis) {
......
...@@ -31,6 +31,20 @@ std::vector<Tensor> split_impl(const Tensor& x, ...@@ -31,6 +31,20 @@ std::vector<Tensor> split_impl(const Tensor& x,
const IntArray& num_or_sections, const IntArray& num_or_sections,
const Scalar& axis); const Scalar& axis);
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const Tensor& mean,
const Tensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu);
std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x, std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad, const Tensor& out_grad,
const Scalar& axis); const Scalar& axis);
......
...@@ -21,15 +21,15 @@ namespace phi { ...@@ -21,15 +21,15 @@ namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void BatchNormGradRawKernel(const Context& dev_ctx, void BatchNormGradRawKernel(const Context& dev_ctx,
const DenseTensor& y_grad,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& scale, const DenseTensor& scale,
const DenseTensor& bias, const DenseTensor& bias,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& saved_mean, const DenseTensor& saved_mean,
const DenseTensor& saved_variance, const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space, paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean, const DenseTensor& y_grad,
paddle::optional<const DenseTensor&> variance,
float momentum, float momentum,
float epsilon, float epsilon,
const std::string& data_layout, const std::string& data_layout,
...@@ -44,15 +44,15 @@ void BatchNormGradRawKernel(const Context& dev_ctx, ...@@ -44,15 +44,15 @@ void BatchNormGradRawKernel(const Context& dev_ctx,
template <typename T, typename Context> template <typename T, typename Context>
void BatchNormGradKernel(const Context& dev_ctx, void BatchNormGradKernel(const Context& dev_ctx,
const DenseTensor& y_grad,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& scale, const DenseTensor& scale,
const DenseTensor& bias, const DenseTensor& bias,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& saved_mean, const DenseTensor& saved_mean,
const DenseTensor& saved_variance, const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space, paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean, const DenseTensor& y_grad,
paddle::optional<const DenseTensor&> variance,
float momentum, float momentum,
float epsilon, float epsilon,
const std::string& data_layout, const std::string& data_layout,
......
...@@ -37,15 +37,16 @@ using ConstEigenVectorArrayMap = ...@@ -37,15 +37,16 @@ using ConstEigenVectorArrayMap =
template <typename T, typename Context> template <typename T, typename Context>
void BatchNormGradRawKernel(const Context& ctx, void BatchNormGradRawKernel(const Context& ctx,
const DenseTensor& y_grad,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& scale, const DenseTensor& scale,
const DenseTensor& bias, const DenseTensor& bias,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& saved_mean, const DenseTensor& saved_mean,
const DenseTensor& saved_variance, const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space, paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean, const DenseTensor& y_grad,
paddle::optional<const DenseTensor&> variance,
float momentum, float momentum,
float epsilon, float epsilon,
const std::string& data_layout_str, const std::string& data_layout_str,
...@@ -122,8 +123,8 @@ void BatchNormGradRawKernel(const Context& ctx, ...@@ -122,8 +123,8 @@ void BatchNormGradRawKernel(const Context& ctx,
ctx.template Alloc<T>(d_x); ctx.template Alloc<T>(d_x);
} }
const T* mean_data = saved_mean.data<T>(); const T* mean_data = nullptr;
const T* inv_var_data = saved_variance.data<T>(); const T* inv_var_data = nullptr;
DenseTensor inv_var_tensor; DenseTensor inv_var_tensor;
if (use_global_stats) { if (use_global_stats) {
const auto* running_mean = mean.get_ptr(); const auto* running_mean = mean.get_ptr();
...@@ -136,6 +137,9 @@ void BatchNormGradRawKernel(const Context& ctx, ...@@ -136,6 +137,9 @@ void BatchNormGradRawKernel(const Context& ctx,
inv_var_tmp = (var_arr + epsilon).sqrt().inverse(); inv_var_tmp = (var_arr + epsilon).sqrt().inverse();
inv_var_data = running_inv_var_data; inv_var_data = running_inv_var_data;
} else {
mean_data = saved_mean.data<T>();
inv_var_data = saved_variance.data<T>();
} }
ConstEigenVectorArrayMap<T> scale_arr(scale.data<T>(), C); ConstEigenVectorArrayMap<T> scale_arr(scale.data<T>(), C);
...@@ -293,15 +297,15 @@ void BatchNormGradRawKernel(const Context& ctx, ...@@ -293,15 +297,15 @@ void BatchNormGradRawKernel(const Context& ctx,
template <typename T, typename Context> template <typename T, typename Context>
void BatchNormGradKernel(const Context& dev_ctx, void BatchNormGradKernel(const Context& dev_ctx,
const DenseTensor& y_grad,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& scale, const DenseTensor& scale,
const DenseTensor& bias, const DenseTensor& bias,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& saved_mean, const DenseTensor& saved_mean,
const DenseTensor& saved_variance, const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space, paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean, const DenseTensor& y_grad,
paddle::optional<const DenseTensor&> variance,
float momentum, float momentum,
float epsilon, float epsilon,
const std::string& data_layout, const std::string& data_layout,
...@@ -313,15 +317,15 @@ void BatchNormGradKernel(const Context& dev_ctx, ...@@ -313,15 +317,15 @@ void BatchNormGradKernel(const Context& dev_ctx,
DenseTensor* scale_grad, DenseTensor* scale_grad,
DenseTensor* bias_grad) { DenseTensor* bias_grad) {
BatchNormGradRawKernel<T, Context>(dev_ctx, BatchNormGradRawKernel<T, Context>(dev_ctx,
y_grad,
x, x,
scale, scale,
bias, bias,
mean,
variance,
saved_mean, saved_mean,
saved_variance, saved_variance,
reserve_space, reserve_space,
mean, y_grad,
variance,
momentum, momentum,
epsilon, epsilon,
data_layout, data_layout,
......
...@@ -306,15 +306,15 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData( ...@@ -306,15 +306,15 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
template <typename T, typename Context> template <typename T, typename Context>
void BatchNormGradRawKernel(const Context &ctx, void BatchNormGradRawKernel(const Context &ctx,
const DenseTensor &y_grad,
const DenseTensor &x, const DenseTensor &x,
const DenseTensor &scale, const DenseTensor &scale,
const DenseTensor &bias, const DenseTensor &bias,
paddle::optional<const DenseTensor &> mean,
paddle::optional<const DenseTensor &> variance,
const DenseTensor &saved_mean, const DenseTensor &saved_mean,
const DenseTensor &saved_variance, const DenseTensor &saved_variance,
paddle::optional<const DenseTensor &> reserve_space, paddle::optional<const DenseTensor &> reserve_space,
paddle::optional<const DenseTensor &> mean, const DenseTensor &y_grad,
paddle::optional<const DenseTensor &> variance,
float momentum, float momentum,
float epsilon_f, float epsilon_f,
const std::string &data_layout_str, const std::string &data_layout_str,
...@@ -863,15 +863,15 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -863,15 +863,15 @@ void BatchNormGradRawKernel(const Context &ctx,
template <typename T, typename Context> template <typename T, typename Context>
void BatchNormGradKernel(const Context &dev_ctx, void BatchNormGradKernel(const Context &dev_ctx,
const DenseTensor &y_grad,
const DenseTensor &x, const DenseTensor &x,
const DenseTensor &scale, const DenseTensor &scale,
const DenseTensor &bias, const DenseTensor &bias,
paddle::optional<const DenseTensor &> mean,
paddle::optional<const DenseTensor &> variance,
const DenseTensor &saved_mean, const DenseTensor &saved_mean,
const DenseTensor &saved_variance, const DenseTensor &saved_variance,
paddle::optional<const DenseTensor &> reserve_space, paddle::optional<const DenseTensor &> reserve_space,
paddle::optional<const DenseTensor &> mean, const DenseTensor &y_grad,
paddle::optional<const DenseTensor &> variance,
float momentum, float momentum,
float epsilon, float epsilon,
const std::string &data_layout, const std::string &data_layout,
...@@ -883,15 +883,15 @@ void BatchNormGradKernel(const Context &dev_ctx, ...@@ -883,15 +883,15 @@ void BatchNormGradKernel(const Context &dev_ctx,
DenseTensor *scale_grad, DenseTensor *scale_grad,
DenseTensor *bias_grad) { DenseTensor *bias_grad) {
BatchNormGradRawKernel<T, Context>(dev_ctx, BatchNormGradRawKernel<T, Context>(dev_ctx,
y_grad,
x, x,
scale, scale,
bias, bias,
mean,
variance,
saved_mean, saved_mean,
saved_variance, saved_variance,
reserve_space, reserve_space,
mean, y_grad,
variance,
momentum, momentum,
epsilon, epsilon,
data_layout, data_layout,
......
...@@ -59,15 +59,17 @@ KernelSignature BatchNormGradOpArgumentMapping( ...@@ -59,15 +59,17 @@ KernelSignature BatchNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"batch_norm_grad", "batch_norm_grad",
{GradVarName("Y"), {
"X", "X",
"Scale", "Scale",
"Bias", "Bias",
"Mean",
"Variance",
"SavedMean", "SavedMean",
"SavedVariance", "SavedVariance",
"ReserveSpace", "ReserveSpace",
"Mean", GradVarName("Y"),
"Variance"}, },
{"momentum", {"momentum",
"epsilon", "epsilon",
"data_layout", "data_layout",
......
...@@ -1339,12 +1339,19 @@ class BatchNorm(layers.Layer): ...@@ -1339,12 +1339,19 @@ class BatchNorm(layers.Layer):
variance_out = self._variance variance_out = self._variance
if _non_static_mode(): if _non_static_mode():
if in_dygraph_mode():
batch_norm_out, t1, t2, t3, t4, _ = _C_ops.final_state_batch_norm(
input, self.weight, self.bias, self._mean, self._variance,
self._momentum, self._epsilon, self._data_layout,
not self.training, self._use_global_stats,
self._trainable_statistics, False)
else:
attrs = ("momentum", self._momentum, "epsilon", self._epsilon, attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", not self.training, "data_layout", "is_test", not self.training, "data_layout",
self._data_layout, "use_mkldnn", self._use_mkldnn, self._data_layout, "use_mkldnn", self._use_mkldnn,
"fuse_with_relu", self._fuse_with_relu, "use_global_stats", "fuse_with_relu", self._fuse_with_relu,
self._use_global_stats, 'trainable_statistics', "use_global_stats", self._use_global_stats,
self._trainable_statistics) 'trainable_statistics', self._trainable_statistics)
batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm( batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
input, self.weight, self.bias, self._mean, self._variance, input, self.weight, self.bias, self._mean, self._variance,
mean_out, variance_out, *attrs) mean_out, variance_out, *attrs)
......
...@@ -81,6 +81,40 @@ class TestBatchNorm(unittest.TestCase): ...@@ -81,6 +81,40 @@ class TestBatchNorm(unittest.TestCase):
self.assertRaises(ValueError, error2d_dataformat) self.assertRaises(ValueError, error2d_dataformat)
self.assertRaises(ValueError, error3d_dataformat) self.assertRaises(ValueError, error3d_dataformat)
def test_eager_api(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
shape = [4, 10, 4, 4]
def compute_v1(x):
with fluid.dygraph.guard(p):
bn = fluid.dygraph.BatchNorm(shape[1])
#bn = paddle.nn.BatchNorm2D(shape[1])
x1 = paddle.to_tensor(x)
x1.stop_gradient = False
y = bn(x1)
y.backward()
return y.numpy(), x1.gradient()
def compute_v2(x):
with fluid.dygraph.guard(p):
with _test_eager_guard():
print("v2")
bn = paddle.nn.BatchNorm2D(shape[1])
x1 = paddle.to_tensor(x)
x1.stop_gradient = False
y = bn(x1)
y.backward()
return y.numpy(), x1.gradient()
x = np.random.randn(*shape).astype("float32")
y1, g1 = compute_v1(x)
y2, g2 = compute_v2(x)
self.assertTrue(np.allclose(g1, g2))
self.assertTrue(np.allclose(y1, y2))
def test_dygraph(self): def test_dygraph(self):
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
......
...@@ -186,15 +186,24 @@ def batch_norm(x, ...@@ -186,15 +186,24 @@ def batch_norm(x,
else: else:
trainable_statistics = not use_global_stats trainable_statistics = not use_global_stats
if in_dynamic_mode(): if in_dygraph_mode():
batch_norm_out, _, _, _, _, _ = _C_ops.final_state_batch_norm(
x, weight, bias, running_mean, running_var, momentum, epsilon,
data_format, not training, use_global_stats, trainable_statistics,
False)
return batch_norm_out
if _in_legacy_dygraph():
# for dygraph need tuple
attrs = ("momentum", momentum, "epsilon", epsilon, "is_test", attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
not training, "data_layout", data_format, "use_mkldnn", False, not training, "data_layout", data_format, "use_mkldnn", False,
"fuse_with_relu", False, "use_global_stats", use_global_stats, "fuse_with_relu", False, "use_global_stats", use_global_stats,
"trainable_statistics", trainable_statistics) "trainable_statistics", trainable_statistics)
batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm( batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
x, weight, bias, running_mean, running_var, mean_out, variance_out, x, weight, bias, running_mean, running_var, mean_out, variance_out,
*attrs) *attrs)
return dygraph_utils._append_activation_in_dygraph( return dygraph_utils._append_activation_in_dygraph(
batch_norm_out, act=None) batch_norm_out, act=None)
......
...@@ -207,6 +207,13 @@ ...@@ -207,6 +207,13 @@
kernel : kernel :
func : auc func : auc
# batch_norm
- api : batch_norm
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
invoke : batch_norm_impl(x, scale, bias, mean, variance, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics, fuse_with_relu)
backward : batch_norm_grad
- api : bce_loss - api : bce_loss
args : (Tensor input, Tensor label) args : (Tensor input, Tensor label)
output : Tensor output : Tensor
......
...@@ -118,6 +118,18 @@ ...@@ -118,6 +118,18 @@
kernel : kernel :
func : atanh_grad func : atanh_grad
- backward_api : batch_norm_grad
forward : batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, scale, bias]
kernel :
func : batch_norm_grad
data_type : out_grad
optional : mean_out, variance_out, reserve_space
- backward_api : bce_loss_grad - backward_api : bce_loss_grad
forward : bce_loss (Tensor input, Tensor label) -> Tensor(out) forward : bce_loss (Tensor input, Tensor label) -> Tensor(out)
args : (Tensor input, Tensor label, Tensor out_grad) args : (Tensor input, Tensor label, Tensor out_grad)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册