Commit 24f85431, authored by HaoRen, committed by Yi Liu

Add center Loss Op Support (#18681)

* support center loss
* change tensor copy API to the high-level API TensorCopy

* test=develop rewrite the center_loss CUDA kernel to make it faster,
add documentation for the center loss API, and update the test function

* test=document_preview test=develop
update documentation of the center loss

* test=document_preview test=develop
modify API.spec, update the test code, and remove an unused const_cast
Parent d21c3914
......
@@ -94,6 +94,7 @@
paddle.fluid.initializer.init_on_cpu (ArgSpec(args=[], varargs=None, keywords=No
paddle.fluid.initializer.NumpyArrayInitializer ('paddle.fluid.initializer.NumpyArrayInitializer', ('document', '064f134a27c16372967d450f499762ab'))
paddle.fluid.initializer.NumpyArrayInitializer.__init__ (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.fc (ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None)), ('document', '1c74f52549814235077ecc34856a95eb'))
paddle.fluid.layers.center_loss (ArgSpec(args=['input', 'label', 'num_classes', 'alpha', 'param_attr', 'update_center'], varargs=None, keywords=None, defaults=(True,)), ('document', '7129819d94625c6104054e8187768589'))
paddle.fluid.layers.embedding (ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')), ('document', '1b4916f765620374ad0fdefe5a352993'))
paddle.fluid.layers.dynamic_lstm (ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None)), ('document', '6d3ee14da70adfa36d85c40b18716ef2'))
paddle.fluid.layers.dynamic_lstmp (ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name', 'h_0', 'c_0', 'cell_clip', 'proj_clip'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None, None, None, None, None)), ('document', 'c37d51aad655c8a9f9b045c64717320a'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/center_loss_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
class CenterLossOp : public framework::OperatorWithKernel {
public:
CenterLossOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CenterLoss should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(ctx->HasInput("CenterUpdateRate"),
"Input(CenterUpdateRate) of CenterLoss should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of CenterLoss should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Centers"),
"Input(Centers) of CenterLoss should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SampleCenterDiff"),
"Output(SampleCenterDiff) of CenterLoss should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"),
"Output(Loss) of CenterLoss should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("CentersOut"),
"Output(CentersOut) of CenterLoss shared data with Centers.");
ctx->SetOutputDim("SampleCenterDiff",
{x_dims[0], product(x_dims) / x_dims[0]});
ctx->SetOutputDim("CentersOut", ctx->GetInputDim("Centers"));
ctx->SetOutputDim("Loss", {x_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Loss");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
class CenterLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of center_loss operator.");
AddInput("Label", "(Tensor) Input tensor of center_loss operator.");
AddInput("Centers", "(Tensor) Input tensor of center_loss operator.");
AddInput("CenterUpdateRate",
"(Tensor) Input tensor of center_loss operator.");
AddOutput("CentersOut", "(Tensor) Input tensor of center_loss operator.");
AddOutput("SampleCenterDiff",
"(Tensor) output tensor of center_loss operator.");
AddOutput("Loss", "(Tensor) Output tensor of center_loss operator.");
AddAttr<int>("cluster_num",
"The output cluster num of the center_loss operator.");
AddAttr<bool>("need_update", "whether need to update center info.");
AddComment(R"DOC(
**CenterLoss operator**
implemention of the center loss function in the papper<<A Discriminative
Feature Learning Approach for Deep Face Recognition>>, equations in this implement
is:loss = 1/2 * (x-y)^2 ,where x(X) means the deep feature(output of last hidden layer )
and y(Label) the target label
)DOC");
}
};
class CenterLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("SampleCenterDiff"),
"Input(SampleCenterDiff) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
"Input(Loss@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>("SampleCenterDiff")->type(), ctx.device_context());
}
};
class CenterLossOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
retv->SetType("center_loss_grad");
retv->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
retv->SetInput("SampleCenterDiff", Output("SampleCenterDiff"));
retv->SetInput("X", Input("X"));
retv->SetOutput(framework::GradVarName("X"), InputGrad("X"));
retv->SetAttrMap(Attrs());
return retv;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(center_loss, ops::CenterLossOp, ops::CenterLossOpMaker,
ops::CenterLossOpGradMaker);
REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp);
REGISTER_OP_CPU_KERNEL(center_loss, ops::CenterLossKernel<CPUCtx, float>,
ops::CenterLossKernel<CPUCtx, double>);
REGISTER_OP_CPU_KERNEL(center_loss_grad,
ops::CenterLossGradKernel<CPUCtx, float>,
ops::CenterLossGradKernel<CPUCtx, double>);
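For reference, here is a minimal NumPy sketch of what the forward kernels above compute; the names feat, labels, and centers are illustrative, not part of the op. SampleCenterDiff is each sample's offset from its class center, and Loss is half the squared L2 norm of that offset.

import numpy as np

def center_loss_forward(feat, labels, centers):
    # feat: (N, D) float32, labels: (N,) int64, centers: (C, D) float32
    # SampleCenterDiff: per-sample difference to the center of its class
    diff = feat - centers[labels]
    # Loss: 1/2 * squared L2 distance per sample, shape (N, 1)
    loss = 0.5 * np.sum(np.square(diff), axis=1, keepdims=True)
    return diff, loss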
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "paddle/fluid/operators/center_loss_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void ComputeDifferent(T *centers_diff, const T *X, const T *centers,
const int64_t *ids, const int64_t N,
const int64_t K, const int64_t D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
int64_t id = ids[idy];
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG(id < N, "received id:", id);
T *out = centers_diff + idy * D;
const T *x = X + idy * D;
const T *cent = centers + id * D;
for (int i = idx; i < D; i += BlockDimX) {
out[i] = x[i] - cent[i];
}
idy += BlockDimY * GridDimX;
}
}
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void UpdateCenters(T *centers, T *centers_diff, const int64_t *ids,
const int64_t N, const int64_t K, const int64_t D,
const T *alpha) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
int count = 1;
int64_t id = ids[idy];
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG(id < N, "received id:", id);
for (int i = 0; i < K; i++) {
if (ids[i] == id) {
count++;
}
}
const T *diff = centers_diff + idy * D;
T *cent = centers + id * D;
for (int i = idx; i < D; i += BlockDimX) {
paddle::platform::CudaAtomicAdd(&cent[i], alpha[0] * diff[i] / count);
}
idy += BlockDimY * GridDimX;
}
}
template <typename DeviceContext, typename T>
class CenterLossCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto &device_context = ctx.template device_context<DeviceContext>();
auto stream = device_context.stream();
auto *X = ctx.Input<Tensor>("X"); // deep feature
auto *labels = ctx.Input<Tensor>("Label");
auto *centers = ctx.Input<Tensor>("Centers");
auto *update_rate = ctx.Input<Tensor>("CenterUpdateRate");
int cluster_num = ctx.Attr<int>("cluster_num");
auto *lr_center = update_rate->data<T>();
bool need_update = ctx.Attr<bool>("need_update");
auto x_data = X->data<T>();
auto label_data = labels->data<int64_t>();
auto x_dims = X->dims();
int batch_size = x_dims[0];
const int deep_feat_dim = x_dims[1];
auto *centers_diff = ctx.Output<Tensor>("SampleCenterDiff");
auto centers_diff_data = centers_diff->mutable_data<T>(ctx.GetPlace());
auto centers_data = centers->data<T>();
auto centers_dim = centers->dims();
auto *out_loss = ctx.Output<Tensor>("Loss");
auto loss_data = out_loss->mutable_data<T>(ctx.GetPlace());
auto *centers_out = ctx.Output<Tensor>("CentersOut");
auto *centers_out_data = centers_out->mutable_data<T>(ctx.GetPlace());
auto ctx_place = ctx.GetPlace();
if (centers != centers_out) {
framework::TensorCopy(
*static_cast<const framework::Tensor *>(centers), ctx_place,
*platform::DeviceContextPool::Instance().Get(ctx_place),
static_cast<framework::Tensor *>(centers_out));
}
int64_t numel = X->numel();
size_t N = centers->dims()[0];
size_t D = centers->dims()[1];
size_t K = labels->numel();
dim3 threads(128, 8);
dim3 grids(8, 1);
ComputeDifferent<T, 128, 8, 8><<<grids, threads, 0, stream>>>(
centers_diff_data, x_data, centers_data, label_data, N, K, D);
auto &place = *ctx.template device_context<DeviceContext>().eigen_device();
auto sub_result = EigenMatrix<T>::From(*centers_diff);
auto sub_res_pow2 = (sub_result * sub_result) / T(2.0);
auto z = EigenVector<T>::Flatten(*out_loss);
z.device(place) = sub_res_pow2.sum(Eigen::array<int, 1>({{1}}));
if (need_update) {
UpdateCenters<T, 128, 8, 8><<<grids, threads, 0, stream>>>(
centers_out_data, centers_diff_data, label_data, N, K, D, lr_center);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using GPUCtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(center_loss, ops::CenterLossCUDAKernel<GPUCtx, float>,
ops::CenterLossCUDAKernel<GPUCtx, double>);
REGISTER_OP_CUDA_KERNEL(center_loss_grad,
ops::CenterLossGradKernel<GPUCtx, float>,
ops::CenterLossGradKernel<GPUCtx, double>);
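When need_update is set, UpdateCenters moves each class center toward the mean of its samples' diffs, damped by the batch count. A hedged NumPy sketch of the same rule follows; the 1 + count denominator mirrors the kernel's count initialization and the CPU path's center_update_count(cluster_num, 1).

import numpy as np

def update_centers(centers, diff, labels, alpha):
    # alpha is the scalar CenterUpdateRate
    new_centers = centers.copy()
    for c in range(centers.shape[0]):
        mask = labels == c
        count = int(mask.sum())
        if count > 0:
            # damped mean of the accumulated diffs for class c
            new_centers[c] += alpha * diff[mask].sum(axis=0) / (1 + count)
    return new_centers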
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cstring>
#include <limits>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename DeviceContext, typename T>
class CenterLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X"); // deep feature
auto *labels = ctx.Input<Tensor>("Label");
auto *centers = ctx.Input<Tensor>("Centers");
auto *update_rate = ctx.Input<Tensor>("CenterUpdateRate");
int cluster_num = ctx.Attr<int>("cluster_num");
auto *lr_center = update_rate->data<T>();
T alpha = lr_center[0];
bool need_update = ctx.Attr<bool>("need_update");
auto x_data = X->data<T>();
auto label_data = labels->data<int64_t>();
auto centers_dim = centers->dims();
auto centers_data = centers->data<T>();
auto x_dims = X->dims();
int batch_size = x_dims[0];
int deep_feat_dim = x_dims[1];
auto centers_diff = ctx.Output<Tensor>("SampleCenterDiff");
auto centers_diff_data = centers_diff->mutable_data<T>(ctx.GetPlace());
auto *out_loss = ctx.Output<Tensor>("Loss");
auto *centers_out = ctx.Output<Tensor>("CentersOut");
auto *centers_out_data = centers_out->mutable_data<T>(ctx.GetPlace());
if (centers_out_data != centers_data) {
int size = centers_out->numel() * sizeof(T);
memcpy(centers_out_data, centers_data, size);
}
std::vector<int> center_update_count(cluster_num, 1);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
auto loss_data = out_loss->mutable_data<T>(ctx.GetPlace());
Tensor centers_diffacc; // used to accumulate all diff
auto centers_diffacc_data =
centers_diffacc.mutable_data<T>(centers_dim, ctx.GetPlace());
int numel = centers_diffacc.numel();
std::memset(centers_diffacc_data, 0, sizeof(T) * numel);
auto blas = math::GetBlas<DeviceContext, T>(ctx);
int tLabel;
const T *x_index;
const T *center_index;
T *center_out_index;
T *center_loss_diff_index;
T *acc_index;
platform::Transform<DeviceContext> trans;
for (int i = 0; i < batch_size; ++i) {
tLabel = label_data[i];
center_update_count[tLabel]++;
x_index = x_data + i * deep_feat_dim; // xi index
center_index = centers_data + tLabel * deep_feat_dim; // center index
center_loss_diff_index = centers_diff_data + i * deep_feat_dim;
trans(dev_ctx, x_index, x_index + deep_feat_dim, center_index,
center_loss_diff_index, SubFunctor<T>());
acc_index = centers_diffacc_data + tLabel * deep_feat_dim;
blas.VADD(deep_feat_dim, center_loss_diff_index, acc_index,
acc_index); // accumulate
loss_data[i] = blas.DOT(deep_feat_dim, center_loss_diff_index,
center_loss_diff_index) /
T(2.0);
}
// update centers data
if (need_update) {
for (int i = 0; i < cluster_num; i++) {
acc_index = centers_diffacc_data + i * deep_feat_dim;
center_out_index = centers_out_data + i * deep_feat_dim;
T scale = alpha / center_update_count[i];
blas.SCAL(deep_feat_dim, scale, acc_index);
blas.VADD(deep_feat_dim, acc_index, center_out_index, center_out_index);
}
}
}
};
template <typename DeviceContext, typename T>
class CenterLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *in0 = context.Input<Tensor>("SampleCenterDiff");
auto *in1 = context.Input<Tensor>(framework::GradVarName("Loss"));
auto *x_g = context.Output<Tensor>(framework::GradVarName("X"));
auto sub_result = EigenMatrix<T>::From(*in0);
auto out_grad = EigenMatrix<T>::From(*in1);
auto x_dims = x_g->dims();
int cols = x_g->numel() / x_dims[0];
// calculate gradient
auto grad_mat =
(out_grad.broadcast(Eigen::array<int, 2>({{1, cols}}))) * sub_result;
// propagate back to input
auto &eigen_place =
*context.template device_context<DeviceContext>().eigen_device();
x_g->mutable_data<T>(context.GetPlace());
// eigen matrix
auto x_grad =
EigenMatrix<T>::From(*x_g, framework::make_ddim({x_dims[0], cols}));
x_grad.device(eigen_place) = grad_mat;
}
};
} // namespace operators
} // namespace paddle
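The backward pass above is simple because Loss_i = 1/2 * ||diff_i||^2 implies dLoss_i/dX_i = diff_i; CenterLossGradKernel just broadcasts the incoming loss gradient over the feature dimension and multiplies by SampleCenterDiff. A minimal NumPy equivalent, under the same illustrative shapes as the earlier sketches:

import numpy as np

def center_loss_backward(loss_grad, diff):
    # loss_grad: (N, 1); diff: (N, D)  ->  x_grad: (N, D)
    return loss_grad * diff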
......
@@ -37,6 +37,7 @@
from ..dygraph import layers
__all__ = [
'fc',
'center_loss',
'embedding',
'dynamic_lstm',
'dynamic_lstmp',
......
@@ -354,6 +355,92 @@
def fc(input,
return helper.append_activation(pre_activation)
def center_loss(input,
label,
num_classes,
alpha,
param_attr,
update_center=True):
"""
**Center loss Cost layer**
This layer accepts input (deep features,the output of the last hidden layer)
and target label and return the center loss cost
For deep features, :math:`X`, and target labels, :math:`Y`, the equation is:
.. math::
Out = \\frac{1}{2}(X - Y)^2
Args:
input (Variable): a 2-D tensor with shape[N x M].
label (Variable): the groud truth which is a 2-D tensor
with shape[N x 1],where N is the batch size.
num_classes (int): the number of classification categories.
alpha (float|Variable): learning rate of centers.
param_attr (ParamAttr): Attribute initializer of centers.
update_center (bool): whether to update value of center.
Returns:
Variable: 2-D tensor with shape [N * 1]
Examples:
    .. code-block:: python

        import paddle.fluid as fluid

        input = fluid.layers.data(name='x', shape=[20, 30], dtype='float32')
        label = fluid.layers.data(name='y', shape=[20, 1], dtype='int64')
        num_classes = 1000
        alpha = 0.01
        param_attr = fluid.initializer.Xavier(uniform=False)
        center_loss = fluid.layers.center_loss(input=input,
                                               label=label,
                                               num_classes=num_classes,
                                               alpha=alpha,
                                               param_attr=param_attr,
                                               update_center=True)
"""
helper = LayerHelper('center_loss', **locals())
dtype = helper.input_dtype()
centers_shape = [num_classes, input.shape[1]]
centers_param = helper.create_parameter(
attr=param_attr, shape=centers_shape, dtype=dtype)
centers_param.stop_gradient = True
if isinstance(alpha, Variable):
alpha_param = alpha
else:
assert isinstance(alpha, float)
alpha_param = helper.create_variable(
name="centerloss_alpha",
shape=[1],
dtype="float32",
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=True,
stop_gradient=True,
initializer=Constant(alpha))
centersdiff = helper.create_variable_for_type_inference(dtype=input.dtype)
loss = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='center_loss',
inputs={
'X': [input],
'Label': [label],
'Centers': [centers_param],
'CenterUpdateRate': [alpha_param]
},
outputs={
'SampleCenterDiff': [centersdiff],
'Loss': [loss],
'CentersOut': [centers_param]
},
attrs={'cluster_num': num_classes,
'need_update': update_center})
return loss
def embedding(input,
size,
is_sparse=False,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
class TestCenterLossOp(OpTest):
def setUp(self):
self.op_type = "center_loss"
self.dtype = np.float32
self.init_dtype_type()
batch_size = 6
feet_dim = 10
cluster_num = 8
self.attrs = {}
self.attrs['cluster_num'] = cluster_num
self.attrs['lambda'] = 0.1
self.config()
self.attrs['need_update'] = self.need_update
labels = np.random.randint(cluster_num, size=batch_size, dtype='int64')
feat = np.random.random((batch_size, feet_dim)).astype(np.float32)
centers = np.random.random((cluster_num, feet_dim)).astype(np.float32)
var_sum = np.zeros((cluster_num, feet_dim), dtype=np.float32)
centers_select = centers[labels]
output = feat - centers_select
diff_square = np.square(output).reshape(batch_size, feet_dim)
loss = 0.5 * np.sum(diff_square, axis=1).reshape(batch_size, 1)
cout = []
for i in range(cluster_num):
cout.append(0)
for i in range(batch_size):
cout[labels[i]] += 1
var_sum[labels[i]] += output[i]
for i in range(cluster_num):
var_sum[i] /= (1 + cout[i])
var_sum *= 0.1
result = centers + var_sum
rate = np.array([0.1]).astype(np.float32)
self.inputs = {
'X': feat,
'Label': labels,
'Centers': centers,
'CenterUpdateRate': rate
}
if self.need_update:
self.outputs = {
'SampleCenterDiff': output,
'Loss': loss,
'CentersOut': result
}
else:
self.outputs = {
'SampleCenterDiff': output,
'Loss': loss,
'CentersOut': centers
}
def config(self):
self.need_update = True
def init_dtype_type(self):
pass
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Loss')
class TestCenterLossOpNoUpdate(TestCenterLossOp):
def config(self):
self.need_update = False
if __name__ == "__main__":
unittest.main()