diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 95156274acecb475af54d71863ef899c24025ba7..cf825998979f15a797ea7c501918f4c9490876cf 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -94,6 +94,7 @@ paddle.fluid.initializer.init_on_cpu (ArgSpec(args=[], varargs=None, keywords=No
 paddle.fluid.initializer.NumpyArrayInitializer ('paddle.fluid.initializer.NumpyArrayInitializer', ('document', '064f134a27c16372967d450f499762ab'))
 paddle.fluid.initializer.NumpyArrayInitializer.__init__ (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.layers.fc (ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None)), ('document', '1c74f52549814235077ecc34856a95eb'))
+paddle.fluid.layers.center_loss (ArgSpec(args=['input', 'label', 'num_classes', 'alpha', 'param_attr', 'update_center'], varargs=None, keywords=None, defaults=(True,)), ('document', '7129819d94625c6104054e8187768589'))
 paddle.fluid.layers.embedding (ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')), ('document', '1b4916f765620374ad0fdefe5a352993'))
 paddle.fluid.layers.dynamic_lstm (ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None)), ('document', '6d3ee14da70adfa36d85c40b18716ef2'))
 paddle.fluid.layers.dynamic_lstmp (ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name', 'h_0', 'c_0', 'cell_clip', 'proj_clip'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None, None, None, None, None)), ('document', 'c37d51aad655c8a9f9b045c64717320a'))
diff --git a/paddle/fluid/operators/center_loss_op.cc b/paddle/fluid/operators/center_loss_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bf766a056a767f4b5e152800e9305d1f51f6d901
--- /dev/null
+++ b/paddle/fluid/operators/center_loss_op.cc
@@ -0,0 +1,157 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/center_loss_op.h"
+#include <memory>
+#include <string>
+
+namespace paddle {
+namespace operators {
+class CenterLossOp : public framework::OperatorWithKernel {
+ public:
+  CenterLossOp(const std::string &type,
+               const framework::VariableNameMap &inputs,
+               const framework::VariableNameMap &outputs,
+               const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of CenterLoss should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
+
+    PADDLE_ENFORCE(ctx->HasInput("CenterUpdateRate"),
+                   "Input(CenterUpdateRate) of CenterLoss should not be null.");
+
+    PADDLE_ENFORCE(ctx->HasInput("Label"),
+                   "Input(Label) of CenterLoss should not be null.");
+
+    PADDLE_ENFORCE(ctx->HasInput("Centers"),
+                   "Input(Centers) of CenterLoss should not be null.");
+
+    PADDLE_ENFORCE(
+        ctx->HasOutput("SampleCenterDiff"),
+        "Output(SampleCenterDiff) of CenterLoss should not be null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
+                   "Output(Loss) of CenterLoss should not be null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput("CentersOut"),
+                   "Output(CentersOut) of CenterLoss should not be null.");
+
+    ctx->SetOutputDim("SampleCenterDiff",
+                      {x_dims[0], product(x_dims) / x_dims[0]});
+    ctx->SetOutputDim("CentersOut", ctx->GetInputDim("Centers"));
+    ctx->SetOutputDim("Loss", {x_dims[0], 1});
+    ctx->ShareLoD("X", /*->*/ "Loss");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
+  }
+};
+
+class CenterLossOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) Input tensor of center_loss operator.");
+    AddInput("Label", "(Tensor) Input tensor of center_loss operator.");
+    AddInput("Centers", "(Tensor) Input tensor of center_loss operator.");
+    AddInput("CenterUpdateRate",
+             "(Tensor) Input tensor of center_loss operator.");
+
+    AddOutput("CentersOut", "(Tensor) Output tensor of center_loss operator.");
+    AddOutput("SampleCenterDiff",
+              "(Tensor) Output tensor of center_loss operator.");
+    AddOutput("Loss", "(Tensor) Output tensor of center_loss operator.");
+
+    AddAttr<int>("cluster_num",
+                 "The output cluster num of the center_loss operator.");
+    AddAttr<bool>("need_update", "whether to update the center info.");
+    AddComment(R"DOC(
+**CenterLoss operator**
+
+Implementation of the center loss function. The loss computed here is
+loss = 1/2 * (x - y)^2, where x (X) is the deep feature (the output of the
+last hidden layer) and y is the center corresponding to the target label
+(Label).
+)DOC");
+  }
+};
+
+class CenterLossGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("SampleCenterDiff"),
+                   "Input(SampleCenterDiff) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
+                   "Input(Loss@GRAD) should not be null");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) should not be null");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_grad_name = framework::GradVarName("X");
+
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<Tensor>("SampleCenterDiff")->type(), ctx.device_context());
+  }
+};
+
+class CenterLossOpGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
+    retv->SetType("center_loss_grad");
+    retv->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
+    retv->SetInput("SampleCenterDiff", Output("SampleCenterDiff"));
+    retv->SetInput("X", Input("X"));
+    retv->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+
+    retv->SetAttrMap(Attrs());
+    return retv;
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CPUCtx = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(center_loss, ops::CenterLossOp, ops::CenterLossOpMaker,
+                  ops::CenterLossOpGradMaker);
+
+REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp);
+
+REGISTER_OP_CPU_KERNEL(center_loss, ops::CenterLossKernel<CPUCtx, float>,
+                       ops::CenterLossKernel<CPUCtx, double>);
+
+REGISTER_OP_CPU_KERNEL(center_loss_grad,
+                       ops::CenterLossGradKernel<CPUCtx, float>,
+                       ops::CenterLossGradKernel<CPUCtx, double>);
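The shapes fixed by `InferShape` above make the op's forward contract easy to state: for a feature batch `X` of shape `[N, D]`, `SampleCenterDiff` is `[N, D]` and `Loss` is `[N, 1]`. As a rough reference (not part of the patch, and with illustrative names only), the forward pass can be sketched in NumPy:

```python
import numpy as np

def center_loss_forward(x, labels, centers):
    """NumPy sketch of the center_loss forward pass.

    x:       [N, D] deep features
    labels:  [N]    int64 class ids
    centers: [C, D] current class centers
    """
    diff = x - centers[labels]               # SampleCenterDiff, [N, D]
    loss = 0.5 * np.sum(diff * diff, axis=1, keepdims=True)  # Loss, [N, 1]
    return diff, loss
```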
diff --git a/paddle/fluid/operators/center_loss_op.cu b/paddle/fluid/operators/center_loss_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..eb172fb1f1e82ae3da2ce7d5cf4a76eb5a43e0dc
--- /dev/null
+++ b/paddle/fluid/operators/center_loss_op.cu
@@ -0,0 +1,147 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+#include "paddle/fluid/operators/center_loss_op.h"
+#include "paddle/fluid/platform/assert.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_info.h"
+namespace paddle {
+namespace operators {
+
+using platform::PADDLE_CUDA_NUM_THREADS;
+
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
+__global__ void ComputeDifferent(T *centers_diff, const T *X, const T *centers,
+                                 const int64_t *ids, const int64_t N,
+                                 const int64_t K, const int64_t D) {
+  int idx = threadIdx.x;
+  int idy = blockIdx.x + threadIdx.y * GridDimX;
+
+  while (idy < K) {
+    int64_t id = ids[idy];
+    PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
+    PADDLE_ASSERT_MSG(id < N, "received id:", id);
+    T *out = centers_diff + idy * D;
+    const T *x = X + idy * D;
+    const T *cent = centers + id * D;
+    for (int i = idx; i < D; i += BlockDimX) {
+      out[i] = x[i] - cent[i];
+    }
+    idy += BlockDimY * GridDimX;
+  }
+}
+
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
+__global__ void UpdateCenters(T *centers, T *centers_diff, const int64_t *ids,
+                              const int64_t N, const int64_t K, const int64_t D,
+                              const T *alpha) {
+  int idx = threadIdx.x;
+  int idy = blockIdx.x + threadIdx.y * GridDimX;
+  while (idy < K) {
+    int count = 1;  // starts at 1, so the division below is never by zero
+    int64_t id = ids[idy];
+    PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
+    PADDLE_ASSERT_MSG(id < N, "received id:", id);
+
+    for (int i = 0; i < K; i++) {
+      if (ids[i] == id) {
+        count++;
+      }
+    }
+    const T *diff = centers_diff + idy * D;
+    T *cent = centers + id * D;
+    for (int i = idx; i < D; i += BlockDimX) {
+      paddle::platform::CudaAtomicAdd(&cent[i], alpha[0] * diff[i] / count);
+    }
+    idy += BlockDimY * GridDimX;
+  }
+}
+
+template <typename DeviceContext, typename T>
+class CenterLossCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto &device_context = ctx.template device_context<DeviceContext>();
+    auto stream = device_context.stream();
+    auto *X = ctx.Input<Tensor>("X");  // deep feature
+    auto *labels = ctx.Input<Tensor>("Label");
+    auto *centers = ctx.Input<Tensor>("Centers");
+    auto *update_rate = ctx.Input<Tensor>("CenterUpdateRate");
+    int cluster_num = ctx.Attr<int>("cluster_num");
+    auto *lr_center = update_rate->data<T>();
+    bool need_update = static_cast<bool>(ctx.Attr<bool>("need_update"));
+
+    auto x_data = X->data<T>();
+    auto label_data = labels->data<int64_t>();
+
+    auto x_dims = X->dims();
+    int batch_size = x_dims[0];
+    const int deep_feat_dim = x_dims[1];
+
+    auto *centers_diff = ctx.Output<Tensor>("SampleCenterDiff");
+    auto centers_diff_data = centers_diff->mutable_data<T>(ctx.GetPlace());
+
+    auto centers_data = centers->data<T>();
+    auto centers_dim = centers->dims();
+    auto *out_loss = ctx.Output<Tensor>("Loss");
+    auto loss_data = out_loss->mutable_data<T>(ctx.GetPlace());
+
+    auto *centers_out = ctx.Output<Tensor>("CentersOut");
+    auto *centers_out_data = centers_out->mutable_data<T>(ctx.GetPlace());
+
+    auto ctx_place = ctx.GetPlace();
+    if (centers != centers_out) {
+      framework::TensorCopy(
+          *static_cast<const framework::Tensor *>(centers), ctx_place,
+          *platform::DeviceContextPool::Instance().Get(ctx_place),
+          static_cast<framework::Tensor *>(centers_out));
+    }
+
+    int64_t numel = X->numel();
+
+    size_t N = centers->dims()[0];
+    size_t D = centers->dims()[1];
+    size_t K = labels->numel();
+
+    dim3 threads(128, 8);
+    dim3 grids(8, 1);
+
+    ComputeDifferent<T, 128, 8, 8><<<grids, threads, 0, stream>>>(
+        centers_diff_data, x_data, centers_data, label_data, N, K, D);
+
+    auto &place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto sub_result = EigenMatrix<T>::From(*centers_diff);
+
+    auto sub_res_pow2 = (sub_result * sub_result) / T(2.0);
+    auto z = EigenVector<T>::Flatten(*out_loss);
+    z.device(place) = sub_res_pow2.sum(Eigen::array<int, 1>({{1}}));
+    if (need_update) {
+      UpdateCenters<T, 128, 8, 8><<<grids, threads, 0, stream>>>(
+          centers_out_data, centers_diff_data, label_data, N, K, D, lr_center);
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using GPUCtx = paddle::platform::CUDADeviceContext;
+REGISTER_OP_CUDA_KERNEL(center_loss, ops::CenterLossCUDAKernel<GPUCtx, float>,
+                        ops::CenterLossCUDAKernel<GPUCtx, double>);
+
+REGISTER_OP_CUDA_KERNEL(center_loss_grad,
+                        ops::CenterLossGradKernel<GPUCtx, float>,
+                        ops::CenterLossGradKernel<GPUCtx, double>);
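The `UpdateCenters` kernel above (and the CPU kernel in center_loss_op.h below) moves each class center by `alpha` times the mean of its samples' diffs, where the denominator is 1 plus the number of samples of that class in the batch, matching the expected values in the unit test. A NumPy sketch (not part of the patch, names illustrative):

```python
import numpy as np

def update_centers(centers, diff, labels, alpha):
    """NumPy sketch of the center update performed when need_update is set.

    centers: [C, D] current centers
    diff:    [N, D] = x - centers[labels] (SampleCenterDiff)
    labels:  [N]    int64 class ids
    """
    new_centers = centers.copy()
    for c in np.unique(labels):
        mask = labels == c
        # denominator is 1 + (#samples of class c), as in both kernels
        new_centers[c] += alpha * diff[mask].sum(axis=0) / (1 + mask.sum())
    return new_centers
```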
diff --git a/paddle/fluid/operators/center_loss_op.h b/paddle/fluid/operators/center_loss_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f134bd0cd3c7a565019c92bf08ee4c565ba67ac5
--- /dev/null
+++ b/paddle/fluid/operators/center_loss_op.h
@@ -0,0 +1,155 @@
+/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <cstring>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/functors.h"
+#include "paddle/fluid/platform/transform.h"
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename T>
+struct SubFunctor {
+  inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
+};
+
+template <typename DeviceContext, typename T>
+class CenterLossKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *X = ctx.Input<Tensor>("X");  // deep feature
+    auto *labels = ctx.Input<Tensor>("Label");
+    auto *centers = ctx.Input<Tensor>("Centers");
+    auto *update_rate = ctx.Input<Tensor>("CenterUpdateRate");
+    int cluster_num = ctx.Attr<int>("cluster_num");
+    auto *lr_center = update_rate->data<T>();
+    T alpha = lr_center[0];
+    bool need_update = static_cast<bool>(ctx.Attr<bool>("need_update"));
+
+    auto x_data = X->data<T>();
+    auto label_data = labels->data<int64_t>();
+
+    auto centers_dim = centers->dims();
+    auto centers_data = centers->data<T>();
+
+    auto x_dims = X->dims();
+    int batch_size = x_dims[0];
+    int deep_feat_dim = x_dims[1];
+
+    auto centers_diff = ctx.Output<Tensor>("SampleCenterDiff");
+    auto centers_diff_data = centers_diff->mutable_data<T>(ctx.GetPlace());
+    auto *out_loss = ctx.Output<Tensor>("Loss");
+
+    auto *centers_out = ctx.Output<Tensor>("CentersOut");
+    auto *centers_out_data = centers_out->mutable_data<T>(ctx.GetPlace());
+
+    if (centers_out_data != centers_data) {
+      int size = centers_out->numel() * sizeof(T);
+      memcpy(centers_out_data, centers_data, size);
+    }
+
+    std::vector<int> center_update_count(cluster_num, 1);
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+
+    auto loss_data = out_loss->mutable_data<T>(ctx.GetPlace());
+
+    Tensor centers_diffacc;  // used to accumulate all diff
+    auto centers_diffacc_data =
+        centers_diffacc.mutable_data<T>(centers_dim, ctx.GetPlace());
+    int numel = centers_diffacc.numel();
+    std::memset(centers_diffacc_data, 0, sizeof(T) * numel);
+
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    int tLabel;
+
+    const T *x_index;
+    const T *center_index;
+    T *center_out_index;
+    T *center_loss_diff_index;
+    T *acc_index;
+    platform::Transform<DeviceContext> trans;
+
+    for (int i = 0; i < batch_size; ++i) {
+      tLabel = label_data[i];
+      center_update_count[tLabel]++;
+      x_index = x_data + i * deep_feat_dim;                  // xi index
+      center_index = centers_data + tLabel * deep_feat_dim;  // center index
+      center_loss_diff_index = centers_diff_data + i * deep_feat_dim;
+      trans(dev_ctx, x_index, x_index + deep_feat_dim, center_index,
+            center_loss_diff_index, SubFunctor<T>());
+
+      acc_index = centers_diffacc_data + tLabel * deep_feat_dim;
+      blas.VADD(deep_feat_dim, center_loss_diff_index, acc_index,
+                acc_index);  // accumulate
+      loss_data[i] = blas.DOT(deep_feat_dim, center_loss_diff_index,
+                              center_loss_diff_index) /
+                     T(2.0);
+    }
+
+    // update centers data
+    if (need_update == true) {
+      for (int i = 0; i < cluster_num; i++) {
+        acc_index = centers_diffacc_data + i * deep_feat_dim;
+        center_out_index = centers_out_data + i * deep_feat_dim;
+        T scale = alpha / center_update_count[i];
+        blas.SCAL(deep_feat_dim, scale, acc_index);
+        blas.VADD(deep_feat_dim, acc_index, center_out_index,
+                  center_out_index);
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class CenterLossGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *in0 = context.Input<Tensor>("SampleCenterDiff");
+    auto *in1 = context.Input<Tensor>(framework::GradVarName("Loss"));
+    auto *x_g = context.Output<Tensor>(framework::GradVarName("X"));
+    auto sub_result = EigenMatrix<T>::From(*in0);
+    auto out_grad = EigenMatrix<T>::From(*in1);
+
+    auto x_dims = x_g->dims();
+    int cols = x_g->numel() / x_dims[0];
+    // calculate gradient
+    auto grad_mat =
+        (out_grad.broadcast(Eigen::array<int, 2>({{1, cols}}))) * sub_result;
+
+    // propagate back to input
+    auto &eigen_place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    x_g->mutable_data<T>(context.GetPlace());
+    // eigen matrix
+    auto x_grad =
+        EigenMatrix<T>::From(*x_g, framework::make_ddim({x_dims[0], cols}));
+    x_grad.device(eigen_place) = grad_mat;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 8667b321871ad6d4c4df00319554a8f7b8ec0a3c..ff154c23937648c4fa2e7f8d5c742958ca0f299c 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -37,6 +37,7 @@ from ..dygraph import layers
 __all__ = [
     'fc',
+    'center_loss',
     'embedding',
     'dynamic_lstm',
     'dynamic_lstmp',
@@ -354,6 +355,92 @@ def fc(input,
     return helper.append_activation(pre_activation)
 
+
+def center_loss(input,
+                label,
+                num_classes,
+                alpha,
+                param_attr,
+                update_center=True):
+    """
+    **Center loss Cost layer**
+
+    This layer accepts input (deep features, the output of the last hidden
+    layer) and target labels, and returns the center loss cost.
+
+    For deep features :math:`X` and the centers :math:`Y` selected by the
+    target labels, the equation is:
+
+    .. math::
+
+        Out = \\frac{1}{2}(X - Y)^2
+
+    Args:
+        input (Variable): a 2-D tensor with shape [N x M].
+        label (Variable): the ground truth, a 2-D tensor with
+            shape [N x 1], where N is the batch size.
+        num_classes (int): the number of classification categories.
+        alpha (float|Variable): learning rate of centers.
+        param_attr (ParamAttr): attribute initializer of centers.
+        update_center (bool): whether to update the value of the centers.
+
+    Returns:
+        Variable: 2-D tensor with shape [N x 1]
+
+    Examples:
+        .. code-block:: python
+
+          import paddle.fluid as fluid
+
+          input = fluid.layers.data(name='x', shape=[20, 30], dtype='float32')
+          label = fluid.layers.data(name='y', shape=[20, 1], dtype='int64')
+          num_classes = 1000
+          alpha = 0.01
+          param_attr = fluid.initializer.Xavier(uniform=False)
+          center_loss = fluid.layers.center_loss(input=input,
+                                                 label=label,
+                                                 num_classes=num_classes,
+                                                 alpha=alpha,
+                                                 param_attr=param_attr,
+                                                 update_center=True)
+    """
+    helper = LayerHelper('center_loss', **locals())
+    dtype = helper.input_dtype()
+    centers_shape = [num_classes, input.shape[1]]
+    centers_param = helper.create_parameter(
+        attr=param_attr, shape=centers_shape, dtype=dtype)
+    centers_param.stop_gradient = True
+    if isinstance(alpha, Variable):
+        alpha_param = alpha
+    else:
+        assert isinstance(alpha, float)
+        alpha_param = helper.create_variable(
+            name="centerloss_alpha",
+            shape=[1],
+            dtype="float32",
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=True,
+            stop_gradient=True,
+            initializer=Constant(alpha))
+
+    centersdiff = helper.create_variable_for_type_inference(dtype=input.dtype)
+    loss = helper.create_variable_for_type_inference(dtype=input.dtype)
+    helper.append_op(
+        type='center_loss',
+        inputs={
+            'X': [input],
+            'Label': [label],
+            'Centers': [centers_param],
+            'CenterUpdateRate': [alpha_param]
+        },
+        outputs={
+            'SampleCenterDiff': [centersdiff],
+            'Loss': [loss],
+            'CentersOut': [centers_param]
+        },
+        attrs={'cluster_num': num_classes,
+               'need_update': update_center})
+    return loss
+
+
 def embedding(input,
               size,
               is_sparse=False,
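For trying the new layer outside the unit tests, here is a minimal end-to-end sketch (not part of the patch), assuming the fluid 1.x static-graph API; shapes and hyperparameters are illustrative:

```python
import numpy as np
import paddle.fluid as fluid

# features would normally come from the last hidden layer of a network
feat = fluid.layers.data(name='feat', shape=[16], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
loss = fluid.layers.center_loss(
    input=feat, label=label, num_classes=10, alpha=0.05,
    param_attr=fluid.initializer.Xavier(uniform=False))
avg_loss = fluid.layers.mean(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out, = exe.run(feed={'feat': np.random.rand(4, 16).astype('float32'),
                     'label': np.random.randint(10, size=(4, 1)).astype('int64')},
               fetch_list=[avg_loss])
print(out)  # mean of the per-sample center loss
```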
diff --git a/python/paddle/fluid/tests/unittests/test_center_loss.py b/python/paddle/fluid/tests/unittests/test_center_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..50dd6b5e940d25fa95b16d53858396fd6fa476f4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_center_loss.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid.core as core
+
+
+class TestCenterLossOp(OpTest):
+    def setUp(self):
+        self.op_type = "center_loss"
+        self.dtype = np.float32
+        self.init_dtype_type()
+        batch_size = 6
+        feet_dim = 10
+        cluster_num = 8
+        self.attrs = {}
+        self.attrs['cluster_num'] = cluster_num
+        self.attrs['lambda'] = 0.1
+        self.config()
+        self.attrs['need_update'] = self.need_update
+        labels = np.random.randint(cluster_num, size=batch_size, dtype='int64')
+        feat = np.random.random((batch_size, feet_dim)).astype(np.float32)
+        centers = np.random.random((cluster_num, feet_dim)).astype(np.float32)
+        var_sum = np.zeros((cluster_num, feet_dim), dtype=np.float32)
+        centers_select = centers[labels]
+        output = feat - centers_select
+        diff_square = np.square(output).reshape(batch_size, feet_dim)
+        loss = 0.5 * np.sum(diff_square, axis=1).reshape(batch_size, 1)
+        cout = []
+        for i in range(cluster_num):
+            cout.append(0)
+        for i in range(batch_size):
+            cout[labels[i]] += 1
+            var_sum[labels[i]] += output[i]
+        for i in range(cluster_num):
+            var_sum[i] /= (1 + cout[i])
+        var_sum *= 0.1
+        result = centers + var_sum
+        rate = np.array([0.1]).astype(np.float32)
+
+        self.inputs = {
+            'X': feat,
+            'Label': labels,
+            'Centers': centers,
+            'CenterUpdateRate': rate
+        }
+
+        if self.need_update == True:
+            self.outputs = {
+                'SampleCenterDiff': output,
+                'Loss': loss,
+                'CentersOut': result
+            }
+        else:
+            self.outputs = {
+                'SampleCenterDiff': output,
+                'Loss': loss,
+                'CentersOut': centers
+            }
+
+    def config(self):
+        self.need_update = True
+
+    def init_dtype_type(self):
+        pass
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Loss')
+
+
+class TestCenterLossOpNoUpdate(TestCenterLossOp):
+    def config(self):
+        self.need_update = False
+
+
+if __name__ == "__main__":
+    unittest.main()
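The gradient path that `test_check_grad` exercises numerically is simple enough to state directly: the grad kernel broadcasts `Loss@GRAD` (shape `[N, 1]`) across the feature columns and multiplies by `SampleCenterDiff`, since d(1/2 * ||x - c||^2)/dx = x - c. A NumPy sketch (illustrative, not part of the patch):

```python
import numpy as np

def center_loss_grad(sample_center_diff, loss_grad):
    """NumPy sketch of the backward pass.

    sample_center_diff: [N, D] = x - centers[labels]
    loss_grad:          [N, 1] upstream gradient of Loss
    returns X@GRAD:     [N, D]
    """
    return loss_grad * sample_center_diff  # [N, 1] broadcasts over [N, D]
```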