未验证 提交 67a2b521 编写于 作者: Q qingqing01 提交者: GitHub

Add affine channel op to speed and save memory for faster-rcnn model. (#13919)

* Add affine channel op.
* Update code and add Python API.
test=develop
* Update API.spec
test=develop
上级 30dfbdee
...@@ -173,6 +173,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None ...@@ -173,6 +173,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
......
...@@ -305,6 +305,7 @@ if (WITH_GPU) ...@@ -305,6 +305,7 @@ if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(conv_op DEPS vol2col depthwise_conv im2col)
op_library(layer_norm_op DEPS cub) op_library(layer_norm_op DEPS cub)
op_library(reduce_mean_op DEPS cub) op_library(reduce_mean_op DEPS cub)
op_library(affine_channel_op DEPS cub)
else() else()
op_library(conv_op DEPS vol2col im2col) op_library(conv_op DEPS vol2col im2col)
endif() endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class AffineChannelOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) Feature map input can be a 4D tensor with order NCHW "
"or NHWC. It also can be a 2D tensor and C is the second "
"dimension.");
AddInput("Scale",
"(Tensor) 1D input of shape (C), the c-th element "
"is the scale factor of the affine transformation "
"for the c-th channel of the input.");
AddInput("Bias",
"(Tensor) 1D input of shape (C), the c-th element "
"is the bias of the affine transformation for the "
"c-th channel of the input.");
AddAttr<std::string>(
"data_layout",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddOutput("Out", "(Tensor) A tensor of the same shape and order with X.");
AddComment(R"DOC(
Applies a separate affine transformation to each channel of the input. Useful
for replacing spatial batch norm with its equivalent fixed transformation.
The input also can be 2D tensor and applies a affine transformation in second
dimension.
$$Out = Scale*X + Bias$$
)DOC");
}
};
class AffineChannelOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of AffineChannelOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", "Out");
}
};
class AffineChannelOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
if (ctx->HasOutput(framework::GradVarName("X"))) {
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
// Scale@GRAD and Bias@GRAD must exist at the same time.
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
"Output(Scale@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
ctx->SetOutputDim(framework::GradVarName("Scale"),
ctx->GetInputDim("Scale"));
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Scale"));
}
}
};
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename DeviceContext, typename T>
class AffineChannelKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* y = ctx.Output<framework::Tensor>("Out");
y->mutable_data<T>(ctx.GetPlace());
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto dims = x->dims();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = x->numel() / N / C;
auto* scale_d = scale->data<T>();
auto* bias_d = bias->data<T>();
ConstEigenVectorArrayMap<T> a_e(scale_d, C);
ConstEigenVectorArrayMap<T> b_e(bias_d, C);
auto* x_d = x->data<T>();
auto* y_d = y->data<T>();
if (layout == framework::DataLayout::kNCHW) {
int stride = C * HxW;
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> x_e(x_d, HxW, C);
EigenArrayMap<T> y_e(y_d, HxW, C);
y_e = (x_e.rowwise() * a_e.transpose()).rowwise() + b_e.transpose();
x_d += stride;
y_d += stride;
}
} else {
int num = N * HxW;
ConstEigenArrayMap<T> x_e(x_d, C, num);
EigenArrayMap<T> y_e(y_d, C, num);
y_e = (x_e.colwise() * a_e).colwise() + b_e;
}
}
};
template <typename DeviceContext, typename T>
class AffineChannelGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dscale =
ctx.Output<framework::Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto dims = x->dims();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = x->numel() / N / C;
auto* x_d = x->data<T>();
auto* dy_d = dy->data<T>();
auto* scale_d = scale->data<T>();
ConstEigenVectorArrayMap<T> scale_e(scale_d, C);
T* dx_d = dx ? dx->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* dscale_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* dbias_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr;
EigenVectorArrayMap<T> dscale_e(dscale_d, C);
EigenVectorArrayMap<T> dbias_e(dbias_d, C);
if (layout == framework::DataLayout::kNCHW) {
// compute dx
int stride = C * HxW;
if (dx) {
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> dy_e(dy_d, HxW, C);
EigenArrayMap<T> dx_e(dx_d, HxW, C);
dx_e = dy_e.rowwise() * scale_e.transpose();
dy_d += stride;
dx_d += stride;
}
}
// compute dscale and dbias
if (dscale && dbias) {
dy_d = dy->data<T>();
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> x_e(x_d, HxW, C);
ConstEigenArrayMap<T> dy_e(dy_d, HxW, C);
if (i == 0) {
dscale_e = (x_e * dy_e).colwise().sum();
} else {
dscale_e += (x_e * dy_e).colwise().sum();
}
if (i == 0) {
dbias_e = dy_e.colwise().sum();
} else {
dbias_e += dy_e.colwise().sum();
}
x_d += stride;
dy_d += stride;
}
}
} else {
int num = N * HxW;
ConstEigenArrayMap<T> dy_e(dy_d, C, num);
// compute dx
if (dx) {
EigenArrayMap<T> dx_e(dx_d, C, num);
dx_e = dy_e.colwise() * scale_e;
}
// compute dscale and dbias
if (dscale && dbias) {
ConstEigenArrayMap<T> x_e(x_d, C, num);
dscale_e = (x_e * dy_e).rowwise().sum();
dbias_e = dy_e.rowwise().sum();
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp,
ops::AffineChannelOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad);
REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel<CPU, float>,
ops::AffineChannelKernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(affine_channel_grad,
ops::AffineChannelGradKernel<CPU, float>,
ops::AffineChannelGradKernel<CPU, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
template <typename T, framework::DataLayout layout, bool HasBias>
__global__ void KeAffineChannelCUDA(const T* x, const T* scale, const T* bias,
const int C, const int HxW, const int num,
T* y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
if (HasBias) {
y[i] = scale[c] * x[i] + bias[c];
} else {
y[i] = scale[c] * x[i];
}
}
}
template <typename DeviceContext, typename T>
class AffineChannelCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* y = ctx.Output<framework::Tensor>("Out");
y->mutable_data<T>(ctx.GetPlace());
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto dims = x->dims();
const int num = x->numel();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = num / N / C;
const T* x_d = x->data<T>();
const T* scale_d = scale->data<T>();
const T* bias_d = bias->data<T>();
T* y_d = y->data<T>();
int block = 1024;
int grid = (num + block - 1) / block;
if (layout == framework::DataLayout::kNCHW) {
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
true><<<grid, block, 0, dev_ctx.stream()>>>(
x_d, scale_d, bias_d, C, HxW, num, y_d);
} else {
KeAffineChannelCUDA<T, framework::DataLayout::kNHWC,
true><<<grid, block, 0, dev_ctx.stream()>>>(
x_d, scale_d, bias_d, C, HxW, num, y_d);
}
}
};
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void AffineChannelScaleBiasGradientCUDAKernel(
const T* dy, const T* x, const int N, const int C, const int HxW, T* dscale,
T* dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T ds_sum = 0;
T db_sum = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
ds_sum += dy[index] * x[index];
db_sum += dy[index];
}
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
if (threadIdx.x == 0) {
dscale[i] = ds_sum;
dbias[i] = db_sum;
}
__syncthreads();
}
}
template <typename DeviceContext, typename T>
class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dscale =
ctx.Output<framework::Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto dims = x->dims();
const int num = x->numel();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = num / N / C;
const T* x_d = x->data<T>();
const T* dy_d = dy->data<T>();
const T* s_d = scale->data<T>();
T* dx_d = dx ? dx->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* ds_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* db_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr;
const int block = 1024;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid1 = (num + block - 1) / block;
int grid2 = std::min(C, max_blocks);
if (layout == framework::DataLayout::kNCHW) {
if (dx) {
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
false><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_d, s_d, nullptr, C, HxW, num, dx_d);
}
if (dscale && dbias) {
AffineChannelScaleBiasGradientCUDAKernel<
T, block, framework::DataLayout::kNCHW><<<grid2, block, 0,
dev_ctx.stream()>>>(
dy_d, x_d, N, C, HxW, ds_d, db_d);
}
} else {
if (dx) {
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
false><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_d, s_d, nullptr, C, HxW, num, dx_d);
}
if (dscale && dbias) {
AffineChannelScaleBiasGradientCUDAKernel<
T, block, framework::DataLayout::kNHWC><<<grid2, block, 0,
dev_ctx.stream()>>>(
dy_d, x_d, N, C, HxW, ds_d, db_d);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(affine_channel,
ops::AffineChannelCUDAKernel<CUDA, float>,
ops::AffineChannelCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(affine_channel_grad,
ops::AffineChannelGradCUDAKernel<CUDA, float>,
ops::AffineChannelGradCUDAKernel<CUDA, double>);
...@@ -153,6 +153,7 @@ __all__ = [ ...@@ -153,6 +153,7 @@ __all__ = [
'mul', 'mul',
'sigmoid_cross_entropy_with_logits', 'sigmoid_cross_entropy_with_logits',
'maxout', 'maxout',
'affine_channel',
] ]
...@@ -7268,3 +7269,44 @@ def maxout(x, groups, name=None): ...@@ -7268,3 +7269,44 @@ def maxout(x, groups, name=None):
attrs={"groups": groups}, attrs={"groups": groups},
outputs={"Out": out}) outputs={"Out": out})
return out return out
def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
"""
Applies a separate affine transformation to each channel of the input.
Useful for replacing spatial batch norm with its equivalent fixed
transformation. The input also can be 2D tensor and applies a affine
transformation in second dimension.
Args:
x (Variable): Feature map input can be a 4D tensor with order NCHW
or NHWC. It also can be a 2D tensor and the affine transformation
is applied in the second dimension.
scale (Variable): 1D input of shape (C), the c-th element is the scale
factor of the affine transformation for the c-th channel of
the input.
bias (Variable): 1D input of shape (C), the c-th element is the bias
of the affine transformation for the c-th channel of the input.
data_layout (string, default NCHW): NCHW or NHWC. If input is 2D
tensor, you can ignore data_layout.
name (str, default None): The name of this layer.
Returns:
out (Variable): A tensor of the same shape and data layout with x.
"""
helper = LayerHelper("affine_channel", **locals())
if name is None:
out = helper.create_tmp_variable(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="affine_channel",
inputs={"X": x,
'Scale': scale,
'Bias': bias},
attrs={"data_layout": data_layout},
outputs={"Out": out})
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
def affine_channel(x, scale, bias, layout):
C = x.shape[1] if layout == 'NCHW' else x.shape[-1]
if len(x.shape) == 4:
new_shape = (1, C, 1, 1) if layout == 'NCHW' else (1, 1, 1, C)
else:
new_shape = (1, C)
scale = scale.reshape(new_shape)
bias = bias.reshape(new_shape)
return x * scale + bias
class TestAffineChannelOp(OpTest):
def setUp(self):
self.op_type = "affine_channel"
self.init_test_case()
x = np.random.random(self.shape).astype("float32")
scale = np.random.random(self.C).astype("float32")
bias = np.random.random(self.C).astype("float32")
y = affine_channel(x, scale, bias, self.layout)
self.inputs = {'X': x, 'Scale': scale, 'Bias': bias}
self.attrs = {'data_layout': self.layout}
self.outputs = {'Out': y}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Scale', 'Bias'], 'Out')
def test_check_grad_stopgrad_dx(self):
self.check_grad(['Scale', 'Bias'], 'Out', no_grad_set=set('X'))
def test_check_grad_stopgrad_dscale_dbias(self):
self.check_grad(['X'], 'Out', no_grad_set=set(['Scale', 'Bias']))
def init_test_case(self):
self.shape = [2, 32, 14, 14]
self.C = 32
self.layout = 'NCHW'
class TestAffineChannelNHWC(TestAffineChannelOp):
def init_test_case(self):
self.shape = [2, 14, 14, 32]
self.C = 32
self.layout = 'NHWC'
class TestAffineChannel2D(TestAffineChannelOp):
def init_test_case(self):
self.shape = [16, 64]
self.C = 64
self.layout = 'NCHW'
class TestAffineChannelNCHWLargeShape(TestAffineChannelOp):
def init_test_case(self):
self.shape = [64, 128, 112, 112]
self.C = 128
self.layout = 'NCHW'
# since the gradient check is very slow in large shape, so skip check_grad
def test_check_grad(self):
pass
def test_check_grad_stopgrad_dx(self):
pass
def test_check_grad_stopgrad_dscale_dbias(self):
pass
class TestAffineChannelNCHWLargeShape(TestAffineChannelNCHWLargeShape):
def init_test_case(self):
self.shape = [64, 112, 112, 512]
self.C = 512
self.layout = 'NHWC'
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册