提交 58ad40cc 编写于 作者: X xuezhong

add sample_logits op

上级 b5ebca47
......@@ -66,7 +66,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
......
......@@ -39,6 +39,7 @@ math_library(cross_entropy)
math_library(cos_sim_functor)
math_library(depthwise_conv)
math_library(im2col)
math_library(sample_prob)
math_library(sampler)
math_library(gru_compute DEPS activation_functions math_function)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sample_prob.h"
namespace paddle {
namespace operators {
namespace math {
template class SampleWithProb<platform::CPUDeviceContext, float>;
template class SampleWithProb<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thrust/random.h>
#include <thrust/sort.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/math/sampler.h"
namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
template <typename T>
__device__ T gpu_adjust_prob(const T prob, const int num_samples,
const int num_tries) {
if (num_samples == num_tries) {
return prob * num_samples;
} else {
return -expm1(num_tries * log1p(-prob));
}
}
class GPULogUniformSampler {
public:
__device__ int64_t Sample(float random, const int range,
const float log_range) const;
__device__ float Probability(int64_t value, const float log_range) const;
};
__device__ int64_t GPULogUniformSampler::Sample(float random, const int range,
const float log_range) const {
// Got Log Uniform distribution from uniform distribution by
// inverse_transform_sampling method
const int64_t value = static_cast<int64_t>(exp(random * log_range)) - 1;
// Mathematically, value should be <= range_, but might not be due to some
// floating point roundoff, so we mod by range_.
return value % range;
}
__device__ float GPULogUniformSampler::Probability(
int64_t value, const float log_range) const {
// Given f(x) = 1/[(x+1) * log_range_]
// The value's probability is integral of f(x) from value to (value + 1)
return (log((value + 2.0) / (value + 1.0))) / log_range;
}
template <typename T>
__global__ void SamplingCondidate(
const size_t n, const int num_tries, const int range, const float log_range,
const int num_true, const std::size_t num_samples,
const int64_t* label_data, int64_t* samples_data, T* probabilities_data) {
const int num_sampled_classes = num_true + num_samples;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
GPULogUniformSampler sampler;
for (; idx < n; idx += blockDim.x * gridDim.x) {
int col_idx = idx % num_sampled_classes;
int row_idx = idx / num_sampled_classes;
if (col_idx < num_true) {
samples_data[idx] = label_data[row_idx * num_true + col_idx];
} else {
samples_data[idx] = samples_data[col_idx];
}
probabilities_data[idx] = sampler.Probability(samples_data[idx], log_range);
probabilities_data[idx] =
gpu_adjust_prob(probabilities_data[idx], num_samples, num_tries);
}
}
template <typename T>
int UniqSampler(const Sampler& sampler, const std::size_t num_samples,
int64_t* samples_data) {
// sample num_samles unique samples for an example, note that they are not
// all negative samples
std::unordered_set<int64_t> tmp_samples;
tmp_samples.clear();
int num_tries = 0;
int j = 0;
while (j < num_samples) {
++num_tries;
auto v = sampler.Sample();
auto insert_ok = tmp_samples.insert(v).second;
if (!insert_ok) {
continue;
}
samples_data[j] = v;
++j;
}
return num_tries;
}
/*
template <typename T>
void Print(Tensor & t, std::string name) {
if (!FLAGS_debug_print) {
return;
}
VLOG(1) << "qxz print "<< name;
VLOG(1) << name << "size = " << t.numel();
size_t size = t.numel();
type *d = t.data<type>();
#ifdef PADDLE_WITH_CUDA
std::vector<type> vec;
platform::DeviceContextPool::Instance().Get(t.place())->Wait();
if (platform::is_gpu_place(t.place())) {
vec.resize(size);
cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost);
d = vec.data();
}
#endif
VLOG(1) << name << " data_ptr = " << static_cast<void*>(d);
std::string out;
for (size_t i = 0; i < size; i++) {
out += std::to_string(d[i]);
out += ",";
}
VLOG(1) << out;
}*/
template <typename T>
void GPUSampleWithProb<T>::operator()(
const platform::CUDADeviceContext& context, const int seed,
const int dict_size, const bool uniq, const std::size_t num_samples,
const Tensor* L, Tensor* S, Tensor* P) {
// UNDERSTAND: dimension issues
const auto lbl_dim = L->dims();
const int batch_size = lbl_dim[0];
const int num_true = lbl_dim[1];
const int num_sampled_classes = num_true + num_samples;
framework::DDim ret_dim{batch_size, num_sampled_classes};
// UNDERSTAND: raw data view
const int64_t* label_data = L->data<int64_t>();
int64_t* samples_data = S->data<int64_t>();
T* probabilities_data = P->data<T>();
int s_size = num_samples;
framework::DDim s_dim{s_size};
Tensor s;
int64_t* s_data = s.mutable_data<int64_t>(s_dim, platform::CPUPlace());
math::LogUniformSampler sampler(dict_size, seed);
int range = dict_size;
float log_range = log(range + 1);
int num_tries = UniqSampler<T>(sampler, num_samples, s_data);
VLOG(1) << "num_tries: " << num_tries;
PADDLE_ENFORCE(cudaMemcpy(samples_data + num_true, s_data,
sizeof(int64_t) * num_samples,
cudaMemcpyHostToDevice));
int threads = 512;
const size_t size = batch_size * num_sampled_classes;
int grid = (batch_size * num_sampled_classes + threads - 1) / threads;
SamplingCondidate<T><<<grid, threads, 0, context.stream()>>>(
size, num_tries, range, log_range, num_true, num_samples, label_data,
samples_data, probabilities_data);
}
template class GPUSampleWithProb<float>;
template class GPUSampleWithProb<double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/sampler.h"
namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
/* UNDERSTAND: utility function to adjust probability for unique sampling,
return whatever as it is if not using unique samping */
template <typename T>
static T adjust_prob(const T prob, const int num_samples, const int num_tries) {
if (num_samples == num_tries) {
return prob * num_samples;
} else {
return -expm1(num_tries * log1p(-prob));
}
}
template <typename DeviceContext, typename T>
class SampleWithProb {
public:
void operator()(const DeviceContext& context, const Sampler& sampler,
const std::size_t num_samples, const Tensor* L, Tensor* S,
Tensor* P) {
// UNDERSTAND: dimension issues
const auto lbl_dim = L->dims();
const int batch_size = lbl_dim[0];
const int num_true = lbl_dim[1];
const int num_sampled_classes = num_true + num_samples;
framework::DDim ret_dim{batch_size, num_sampled_classes};
// UNDERSTAND: raw data view
const int64_t* label_data = L->data<int64_t>();
int64_t* samples_data =
S->mutable_data<int64_t>(ret_dim, context.GetPlace());
T* probabilities_data = P->mutable_data<T>(ret_dim, context.GetPlace());
// temp sets for unique sampling
std::unordered_set<int64_t> tmp_samples;
int j = 0; // column index
// add true labels, not that efficient
while (j < num_true) {
for (int i = 0; i < batch_size; ++i) {
auto samples_index = i * num_sampled_classes + j;
auto v = label_data[i * num_true + j];
samples_data[samples_index] = v;
probabilities_data[samples_index] = sampler.Probability(v);
}
++j;
}
// sample num_samles unique samples for an example, note that they are not
// all negative samples
tmp_samples.clear();
int num_tries = 0;
while (j < num_sampled_classes) {
++num_tries;
auto v = sampler.Sample();
auto insert_ok = tmp_samples.insert(v).second;
if (!insert_ok) {
continue;
}
auto p = sampler.Probability(v);
for (int i = 0; i < batch_size; ++i) {
auto samples_index = i * num_sampled_classes + j;
samples_data[samples_index] = v;
probabilities_data[samples_index] = p;
}
++j;
}
// compute Q(y|x), because of unique sampling, probabilities need to be
// adjusted
for (int k = 0; k < num_sampled_classes; ++k) {
for (int i = 0; i < batch_size; ++i) {
auto samples_index = i * num_sampled_classes + k;
probabilities_data[samples_index] = adjust_prob(
probabilities_data[samples_index], num_samples, num_tries);
}
}
}
};
#ifdef PADDLE_WITH_CUDA
template <typename T>
class GPUSampleWithProb {
public:
void operator()(const platform::CUDADeviceContext& context, const int seed,
const int dict_size, const bool uniq,
const std::size_t num_samples, const Tensor* L, Tensor* S,
Tensor* P);
};
#endif
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sample_logits_op.h"
#include "paddle/fluid/operators/math/sample_prob.h"
namespace paddle {
namespace operators {
class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number.");
AddInput("Label",
"(Tensor) The ground truth which is a 2-D tensor. Label is a "
"Tensor<int64> with shape [N x NT], where NT is the number of"
"true labels for each example.");
AddInput(
"CustomSamples",
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shaoe [N x "
"S+NT]."
"The customized sample labels with true labels at first. This tensor"
"is only use_custom_samples is true.")
.AsDispensable();
AddInput(
"CustomProbabilities",
"(Tensor, default: Tensor<float>), A 2-D tensor with shaoe [N x S+NT]."
"The customized sample probabilities with true labels at first. This "
"tensor is only use_custom_samples is true.")
.AsDispensable();
AddOutput(
"Samples",
"(Tensor, default: Tensor<int64_t>), A 2-D tensor with shape [N x "
"S+NT]."
"The outputs value of sampler by given the true label, where S is the "
"number of negative sample for each example. So Samples includes NT "
"true"
"labels and S negative labels for each example. This will be used in"
"backward calculation.")
.AsIntermediate();
AddOutput(
"Probabilities",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x "
"S+NT]."
"The outputs value of progabilites of samples by given the true label, "
"where S is the "
"number of negative sample for each example. So Samples includes NT "
"true"
"labels and S negative labels for each example.")
.AsIntermediate();
AddOutput("SampledLogits",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape"
"[N x S+NT]. The outputs value of sampled softmax, which will be"
"used in backward calculation.")
.AsIntermediate();
AddOutput("SampledLabel",
"(Tensor, default: Tensor<int64>), A 2-D tensor. The cross "
"entropy loss with shape [N x NT].");
AddAttr<bool>(
"use_custom_samples",
"An indicator whether to use custom samples with probabilities, if True"
"the operator will use custom samples and custom probabilities"
"otherwise, the operator will generate them by itself.")
.SetDefault(false);
AddAttr<bool>(
"uniq",
"An indicator whether to sample non-repetitive negtive labels, if True"
"the operator will sample negtive labels without replacement."
"otherwise, the operator will sample negtive labels with replacement.")
.SetDefault(false);
AddAttr<bool>(
"remove_accidental_hits",
"An indicator whether to remove accidental hits when samples hits true"
"labels, the removal is implemented by subtracting the corresponding"
"logits by float_max to subpress their softmax to be zero.")
.SetDefault(true);
AddAttr<int>("num_samples", "The number of negative samples.");
AddAttr<int>("seed", "Random seed for generating samples").SetDefault(0);
AddComment(R"DOC(
TODO(chenfeiyu): Write documentation for this Operator.
Sampled Softmax With Cross Entropy Operator.
Cross entropy loss with sampled softmax is used as the output layer extensively.
This operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
When the attribute soft_label is set false, this operators expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.
The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
$$Loss_j = -\text{Logit}_{Label_j} +
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
j = 1,..., K$$
2) Soft label (each sample can have a distribution over all classes)
$$Loss_j = -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i -
\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
j = 1,...,K$$
)DOC");
}
};
class SampleLogitsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Logits"),
"Input(Logits) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Samples"),
"Output(Samples) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Probabilities"),
"Output(Probabilities) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("SampledLogits"),
"Output(SampledLogits) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("SampledLabel"),
"Output(SampledLabel) should be not null.");
auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(
logits_dims.size(), 2UL,
"The logits of softmax_with_cross_entropy should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
"The labels should be a 2-D tensor.");
const int num_samples = ctx->Attrs().Get<int>("num_samples");
const int num_sampled_classes = labels_dims[1] + num_samples;
ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("SampledLabel", {logits_dims[0], labels_dims[1]});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits"));
framework::OpKernelType kt =
framework::OpKernelType(data_type, ctx.device_context());
// kt.place_ = platform::CPUPlace();
return kt;
}
};
// UNDERSTAND: InferShape for Grad
class SampleLogitsOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Logits"),
"Input(Logits) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Samples"),
"Input(Samples) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("SampledLogits"),
"Input(SampledLogits) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("SampledLogits")),
"Input(SampledLogits@Grad) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
"Output(Logits@Grad) should be not null.");
auto logit_dims = ctx->GetInputDim("Logits");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
"The label should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL,
"The logits should be a 2-D tensor.");
ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Logits"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("SampledLogits")));
framework::OpKernelType kt =
framework::OpKernelType(data_type, ctx.device_context());
// kt.place_ = platform::CPUPlace();
return kt;
}
};
// UNDERSTAND: what's the rule for making a GradMaker TODO
class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* grad_op = new framework::OpDesc();
grad_op->SetType("sample_logits_grad");
grad_op->SetInput("Logits", Input("Logits"));
grad_op->SetInput("Label", Input("Label"));
grad_op->SetInput("Samples", Output("Samples"));
grad_op->SetInput("SampledLogits", Output("SampledLogits"));
grad_op->SetInput(framework::GradVarName("SampledLogits"),
OutputGrad("SampledLogits"));
grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sample_logits, ops::SampleLogitsOp, ops::SampleLogitsOpMaker,
ops::SampleLogitsGradMaker);
REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad);
REGISTER_OP_CPU_KERNEL(sample_logits, ops::SampleLogitsKernel<float>,
ops::SampleLogitsKernel<double>);
REGISTER_OP_CPU_KERNEL(sample_logits_grad, ops::SampleLogitsGradKernel<float>,
ops::SampleLogitsGradKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/sample_logits_op.h"
namespace paddle {
namespace operators {
DEFINE_bool(debug_print, true, "run debug mode");
// UNDERSTAND: something like take_along_axis in numpy.
template <typename T>
__global__ void GPUTakeAlongD1(size_t size, const int batch_size,
const int array_slice_size,
const int idx_slice_size, const T* p_array,
const int64_t* p_index, T* p_value) {
const auto value_slice_size = idx_slice_size;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = blockDim.x * gridDim.x;
for (; idx < size; idx += step_size) {
int i = idx / idx_slice_size;
auto array_index = p_index[idx];
p_value[idx] = p_array[i * array_slice_size + array_index];
}
}
// UNDERSTAND: something like put_along_axis in numpy but if there is duplicate
// indices, scatter is done in += way.
template <typename T>
__global__ void GPUPutAlongD1(size_t size, const int batch_size,
const int array_slice_size,
const int idx_slice_size, T* p_array,
const int64_t* p_index, const T* p_value) {
const auto value_slice_size = idx_slice_size;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = blockDim.x * gridDim.x;
// size == batch_size
for (; idx < size; idx += step_size) {
int i = idx;
for (int j = 0; j < idx_slice_size; ++j) {
auto array_index = p_index[i * idx_slice_size + j];
p_array[i * array_slice_size + array_index] +=
p_value[i * idx_slice_size + j];
}
}
}
// UNDERSTAND: set label as 0,1,...,num_true-1
template <typename T>
__global__ void GPUSetLabel(size_t size, const int num_true, int64_t* p_array) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = blockDim.x * gridDim.x;
for (; idx < size; idx += step_size) {
p_array[idx] = idx % num_true;
}
}
// UNDERSTAND: compute accidentdal hits from samples and minus corresponding
// logits by a float max, here 1e20
template <typename T>
__global__ void gpu_compute_remove_accidental_hits(const int size,
const int num_true,
const int idx_slice_size,
const int64_t* p_index,
T* p_value) {
const auto value_slice_size = idx_slice_size;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = blockDim.x * gridDim.x;
for (; idx < size; idx += step_size) {
int i = idx / idx_slice_size;
if (idx % idx_slice_size < num_true) continue;
for (int j = 0; j < num_true; ++j) {
const auto true_idx = i * idx_slice_size + j;
if (p_index[true_idx] == p_index[idx]) {
p_value[idx] -= 1e20;
break;
}
}
}
}
template <typename T>
class SampleLogitsCUDAKernel : public framework::OpKernel<T> {
public:
using Tensor = framework::Tensor;
template <typename type>
void Print(const Tensor& t, std::string name) const {
if (!FLAGS_debug_print) {
return;
}
VLOG(1) << "qxz print " << name;
VLOG(1) << name << "size = " << t.numel();
size_t size = t.numel();
type* d = t.data<type>();
#ifdef PADDLE_WITH_CUDA
std::vector<type> vec;
platform::DeviceContextPool::Instance().Get(t.place())->Wait();
if (platform::is_gpu_place(t.place())) {
vec.resize(size);
cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost);
d = vec.data();
}
#endif
VLOG(1) << name << " data_ptr = " << static_cast<void*>(d);
std::string out;
for (size_t i = 0; i < size; i++) {
out += std::to_string(d[i]);
out += ",";
}
VLOG(1) << out;
}
void Compute(const framework::ExecutionContext& context) const override {
// get necessary inputs
const Tensor* logits = context.Input<Tensor>("Logits");
const Tensor* label = context.Input<Tensor>("Label");
VLOG(3) << "Enter SampleLogitsCUDAKernel";
// get necessary outputs
Tensor* samples = context.Output<Tensor>("Samples");
Tensor* probabilities = context.Output<Tensor>("Probabilities");
Tensor* sampled_logits = context.Output<Tensor>("SampledLogits");
Tensor* sampled_label = context.Output<Tensor>("SampledLabel");
// shapes
const auto batch_size = logits->dims()[0];
const auto num_classes = logits->dims()[1];
const auto label_dim = label->dims();
const auto num_true = label_dim[1];
const auto samples_dim = samples->dims();
// attrs
const auto num_samples = context.Attr<int>("num_samples");
const bool use_custom_samples = context.Attr<bool>("use_custom_samples");
const bool uniq = context.Attr<bool>("uniq");
const bool remove_accidental_hits =
context.Attr<bool>("remove_accidental_hits");
// device contexts
auto& dev_ctx = context.cuda_device_context();
// UNDERSTAND: allocate memories for temporaries
sampled_logits->mutable_data<T>(samples_dim, context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(dev_ctx, sampled_logits, static_cast<T>(0));
auto sampled_label_data =
sampled_label->mutable_data<int64_t>(label_dim, context.GetPlace());
int threads = 512;
size_t size = batch_size * num_true;
int grid = (size + threads - 1) / threads;
GPUSetLabel<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, num_true, sampled_label_data);
if (use_custom_samples) {
const Tensor* custom_samples = context.Input<Tensor>("CustomSamples");
const Tensor* custom_probabilities =
context.Input<Tensor>("CustomProbabilities");
samples->ShareDataWith(*custom_samples);
probabilities->ShareDataWith(*custom_probabilities);
} else {
samples->mutable_data<int64_t>(context.GetPlace());
probabilities->mutable_data<T>(samples_dim, context.GetPlace());
// UNDERSTAND: sampling
const auto seed = context.Attr<int>("seed");
auto sampler_with_prob = math::GPUSampleWithProb<T>();
Print<int64_t>(*samples, std::string("samples1"));
sampler_with_prob(context.cuda_device_context(), seed, num_classes, uniq,
num_samples, label, samples, probabilities);
}
Print<int64_t>(*samples, std::string("samples2"));
Print<T>(*probabilities, std::string("probabilities"));
// UNDERSTAND: gather sampled logits and remove accidental hits if needed
const auto num_take = samples->dims()[1];
const auto array_dims = logits->dims();
const auto idx_dims = samples->dims();
const T* p_array = logits->data<T>();
const int64_t* p_index = samples->data<int64_t>();
T* p_value = sampled_logits->data<T>();
// src slice size
const auto array_slice_size = array_dims[1];
// index slice size
const auto idx_slice_size = idx_dims[1];
size = batch_size * num_take;
grid = (size + threads - 1) / threads;
GPUTakeAlongD1<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, batch_size, array_slice_size, idx_slice_size, p_array, p_index,
p_value);
Print<T>(*sampled_logits, std::string("sampled_logits"));
if (remove_accidental_hits) {
const size_t size = batch_size * (num_true + num_samples);
int grid = (size + threads - 1) / threads;
gpu_compute_remove_accidental_hits<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, num_true, idx_slice_size, p_index, p_value);
Print<T>(*sampled_logits,
std::string("sampled_logits_remove_accidental_hits"));
}
// subtracted sampled logits with logQ(y|x)
auto probs = EigenMatrix<T>::From(*probabilities);
auto smp_logits = EigenMatrix<T>::From(*sampled_logits);
smp_logits.device(*dev_ctx.eigen_device()) =
(smp_logits - probs.log().unaryExpr(TolerableValue<T>()))
.unaryExpr(TolerableValue<T>());
Print<T>(*sampled_logits, std::string("sampled_logits_res"));
}
};
template <typename T>
class SampleLogitsGradCUDAKernel : public framework::OpKernel<T> {
public:
using Tensor = framework::Tensor;
template <typename type>
void Print(const Tensor& t, std::string name) const {
if (!FLAGS_debug_print) {
return;
}
VLOG(1) << "qxz print " << name;
VLOG(1) << name << "size = " << t.numel();
size_t size = t.numel();
const type* d = t.data<type>();
#ifdef PADDLE_WITH_CUDA
std::vector<type> vec;
platform::DeviceContextPool::Instance().Get(t.place())->Wait();
if (platform::is_gpu_place(t.place())) {
vec.resize(size);
cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost);
d = vec.data();
}
#endif
VLOG(1) << name << " data_ptr = " << static_cast<const void*>(d);
std::string out;
for (size_t i = 0; i < size; i++) {
out += std::to_string(d[i]);
out += ",";
}
VLOG(1) << out;
}
void Compute(const framework::ExecutionContext& context) const override {
auto logits_grad = context.Output<Tensor>(framework::GradVarName("Logits"));
const Tensor* samples = context.Input<Tensor>("Samples");
const Tensor* sampled_logits_grad =
context.Input<Tensor>(framework::GradVarName("SampledLogits"));
logits_grad->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.cuda_device_context();
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(dev_ctx, logits_grad, static_cast<T>(0));
// UNDERSTAND: scatter it back to logit_grad
const auto batch_size = samples->dims()[0];
const auto num_put = samples->dims()[1];
const auto array_dims = logits_grad->dims();
const auto idx_dims = samples->dims();
T* p_array = logits_grad->data<T>();
const int64_t* p_index = samples->data<int64_t>();
const T* p_value = sampled_logits_grad->data<T>();
// src slice size
const auto array_slice_size = array_dims[1];
// index slice size
const auto idx_slice_size = idx_dims[1];
int threads = 128;
const size_t size = batch_size;
int grid = (size + threads - 1) / threads;
Print<T>(*sampled_logits_grad, std::string("sampled_logits_grad"));
Print<int64_t>(*samples, std::string("samples"));
GPUPutAlongD1<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, batch_size, array_slice_size, idx_slice_size, p_array, p_index,
p_value);
Print<T>(*logits_grad, std::string("logits_grad"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sample_logits, ops::SampleLogitsCUDAKernel<float>,
ops::SampleLogitsCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(sample_logits_grad,
ops::SampleLogitsGradCUDAKernel<float>,
ops::SampleLogitsGradCUDAKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
struct TolerableValue {
HOSTDEVICE T operator()(const T& x) const {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
if (x == INFINITY) return kApproInf;
if (x == -INFINITY) return -kApproInf;
return x;
}
};
// UNDERSTAND: something like take_along_axis in numpy.
template <typename T>
static void CPUTakeAlongD1(const platform::DeviceContext& ctx,
const framework::Tensor& array,
const framework::Tensor& index,
framework::Tensor* value) {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K)
PADDLE_ENFORCE(index.dims().size() == 2 && array.dims().size() == 2 &&
index.dims()[0] == array.dims()[0] &&
index.dims() == value->dims());
const auto batch_size = index.dims()[0];
const auto num_take = index.dims()[1];
const auto array_dims = array.dims();
const auto idx_dims = index.dims();
// UNDERSTAND: no allocations here
const T* p_array = array.data<T>();
const int64_t* p_index = index.data<int64_t>();
T* p_value = value->data<T>();
// src slice size
const auto array_slice_size = array_dims[1];
// index slice size
const auto idx_slice_size = idx_dims[1];
const auto value_slice_size = idx_slice_size;
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_take; ++j) {
auto array_index = p_index[i * idx_slice_size + j];
p_value[i * value_slice_size + j] =
p_array[i * array_slice_size + array_index];
}
}
}
// UNDERSTAND: something like put_along_axis in numpy but if there is duplicate
// indices, scatter is done in += way.
template <typename T>
static void CPUPutAlongD1(const platform::DeviceContext& ctx,
framework::Tensor* array,
const framework::Tensor& index,
const framework::Tensor& value) {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K)
PADDLE_ENFORCE(index.dims().size() == 2 && array->dims().size() == 2 &&
index.dims()[0] == array->dims()[0] &&
index.dims() == value.dims());
const auto batch_size = index.dims()[0];
const auto num_put = index.dims()[1];
auto array_dims = array->dims();
auto idx_dims = index.dims();
// UNDERSTAND: no allocations here
T* p_array = array->data<T>();
const int64_t* p_index = index.data<int64_t>();
const T* p_value = value.data<T>();
// slice sizes
const auto array_slice_size = array_dims[1];
const auto idx_slice_size = idx_dims[1];
const auto value_slice_size = idx_slice_size;
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_put; ++j) {
auto array_index = p_index[i * idx_slice_size + j];
p_array[i * array_slice_size + array_index] +=
p_value[i * value_slice_size + j];
}
}
}
// UNDERSTAND: compute accidentdal hits from samples and minus corresponding
// logits by a float max, here 1e20
template <typename T>
static void compute_remove_accidental_hits(const platform::DeviceContext& ctx,
framework::Tensor* sampled_logits,
const framework::Tensor& samples,
const int num_true) {
const auto batch_size = sampled_logits->dims()[0];
const auto num_sampled_classes = sampled_logits->dims()[1];
T* sampled_logits_data = sampled_logits->data<T>();
const auto samples_data = samples.data<int64_t>();
std::unordered_set<int64_t> tmp_true_labels;
for (int i = 0; i < batch_size; ++i) {
tmp_true_labels.clear();
tmp_true_labels.insert(samples_data + i * num_sampled_classes,
samples_data + i * num_sampled_classes + num_true);
for (int j = num_true; j < num_sampled_classes; ++j) {
const auto idx = i * num_sampled_classes + j;
if (tmp_true_labels.find(samples_data[idx]) != tmp_true_labels.end())
sampled_logits_data[idx] -= 1e20;
}
}
}
template <typename T>
class SampleLogitsKernel : public framework::OpKernel<T> {
public:
using Tensor = framework::Tensor;
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
"This kernel only runs on CPU.");
VLOG(3) << "Enter SampleLogitsKernel";
// get necessary inputs
const Tensor* logits = context.Input<Tensor>("Logits");
const Tensor* label = context.Input<Tensor>("Label");
// get necessary outputs
Tensor* samples = context.Output<Tensor>("Samples");
Tensor* probabilities = context.Output<Tensor>("Probabilities");
Tensor* sampled_logits = context.Output<Tensor>("SampledLogits");
Tensor* sampled_label = context.Output<Tensor>("SampledLabel");
// shapes
const auto batch_size = logits->dims()[0];
const auto num_classes = logits->dims()[1];
const auto label_dim = label->dims();
const auto num_true = label_dim[1];
const auto samples_dim = samples->dims();
// attrs
const auto num_samples = context.Attr<int>("num_samples");
const bool use_custom_samples = context.Attr<bool>("use_custom_samples");
const bool remove_accidental_hits =
context.Attr<bool>("remove_accidental_hits");
// device contexts
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
// UNDERSTAND: allocate memories for temporaries
sampled_logits->mutable_data<T>(samples_dim, context.GetPlace());
auto sampled_label_data =
sampled_label->mutable_data<int64_t>(label_dim, context.GetPlace());
for (int i = 0; i < batch_size; ++i)
for (int j = 0; j < num_true; ++j)
sampled_label_data[i * num_true + j] = j;
if (use_custom_samples) {
const Tensor* custom_samples = context.Input<Tensor>("CustomSamples");
const Tensor* custom_probabilities =
context.Input<Tensor>("CustomProbabilities");
samples->ShareDataWith(*custom_samples);
probabilities->ShareDataWith(*custom_probabilities);
} else {
samples->mutable_data<int64_t>(context.GetPlace());
probabilities->mutable_data<T>(samples_dim, context.GetPlace());
// UNDERSTAND: sampling
const auto seed = context.Attr<int>("seed");
auto sampler_with_prob =
math::SampleWithProb<platform::CPUDeviceContext, T>();
sampler_with_prob(dev_ctx, math::LogUniformSampler(num_classes, seed),
num_samples, label, samples, probabilities);
}
// UNDERSTAND: gather sampled logits and remove accidental hits if needed
CPUTakeAlongD1<T>(dev_ctx, *logits, *samples, sampled_logits);
if (remove_accidental_hits) {
compute_remove_accidental_hits<T>(dev_ctx, sampled_logits, *samples,
num_true);
}
/* Debug
const auto num_sampled_classes = samples_dim[1];
std::cout << "Sampled Logits" << std::endl;
const auto sampled_logits_data = sampled_logits->data<T>();
for (int i = 0; i < sampled_logits->numel(); ++i) {
std::cout << sampled_logits_data[i] << ", ";
if ((i + 1) % num_sampled_classes == 0)
std::cout << std::endl;
}
std::cout << std::endl;
*/
/* Debug
std::cout << "Samples" << std::endl;
const auto samples_data = samples->data<int64_t>();
for (int i = 0; i < samples->numel(); ++i) {
std::cout << samples_data[i] << ", ";
if ((i + 1) % num_sampled_classes == 0)
std::cout << std::endl;
}
std::cout << std::endl;
*/
/* Debug
std::cout << "Probabilities" << std::endl;
const auto probabilities_data = probabilities->data<T>();
for (int i = 0; i < probabilities->numel(); ++i) {
std::cout << probabilities_data[i] << ", ";
if ((i + 1) % num_sampled_classes == 0)
std::cout << std::endl;
}
std::cout << std::endl;
*/
// subtracted sampled logits with logQ(y|x)
auto probs = EigenMatrix<T>::From(*probabilities);
auto smp_logits = EigenMatrix<T>::From(*sampled_logits);
smp_logits.device(*dev_ctx.eigen_device()) =
(smp_logits - probs.log().unaryExpr(TolerableValue<T>()))
.unaryExpr(TolerableValue<T>());
}
};
template <typename T>
class SampleLogitsGradKernel : public framework::OpKernel<T> {
public:
using Tensor = framework::Tensor;
void Compute(const framework::ExecutionContext& context) const override {
auto logits_grad = context.Output<Tensor>(framework::GradVarName("Logits"));
const Tensor* samples = context.Input<Tensor>("Samples");
const Tensor* sampled_logits_grad =
context.Input<Tensor>(framework::GradVarName("SampledLogits"));
logits_grad->mutable_data<T>(context.GetPlace());
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
set_zero(dev_ctx, logits_grad, static_cast<T>(0));
// const bool remove_accidental_hits =
// context.Attr<bool>("remove_accidental_hits");
// UNDERSTAND: scatter it back to logit_grad
CPUPutAlongD1<T>(dev_ctx, logits_grad, *samples, *sampled_logits_grad);
}
};
} // namespace operators
} // namespace paddle
......@@ -131,7 +131,7 @@ def __bootstrap__():
'eager_delete_tensor_gb', 'fast_eager_deletion_mode',
'allocator_strategy', 'reader_queue_speed_test_mode',
'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir',
'inner_op_parallelism', 'enable_parallel_graph'
'inner_op_parallelism', 'enable_parallel_graph', 'debug_print'
]
if 'Darwin' not in sysstr:
read_env_flags.append('use_pinned_memory')
......
......@@ -87,6 +87,7 @@ __all__ = [
'transpose',
'im2sequence',
'nce',
'sample_logits',
'hsigmoid',
'beam_search',
'row_conv',
......@@ -5764,6 +5765,104 @@ def softmax_with_cross_entropy(logits,
return loss
def sample_logits(logits,
label,
num_samples,
uniq=True,
remove_accidental_hits=True,
use_custom_samples=False,
custom_samples=None,
custom_probabilities=None,
seed=0):
"""
**Sampled Softmax With Cross Entropy Operator.**
Cross entropy loss with sampled softmax is used as the output layer for
larger output classes extensively. This operator samples a number of samples
for each example(row), and computes the softmax normalized values for each
row of the sampled tensor, after which cross-entropy loss is computed.
This provides a more numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
For examples with T true labels (T >= 1), we assume that each true label has
a probability of 1/T. For each sample, S samples are generated using a
log uniform distribution. True labels are concatenated with hese samples to
form T + S samples for each example. So, assume the shape of logits is
[N x K], the shape for samples is [N x (T+S)]. For each sampled label, a
probability is calculated, which corresponds to the Q(y|x) in
[Jean et al., 2014](http://arxiv.org/abs/1412.2007).
Logits are sampled according to the sampled labels. Then if
remove_accidental_hits is True, if a sample[i, j] accidentally hits true
labels, then the corresponding sampled_logits[i, j] is minus by 1e20 to
make its softmax result close to zero. Then samled logits are subtracted by
logQ(y|x), these sampled logits and re-indexed labels are used to compute
a softmax with cross entropy.
Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor
with shape [N x K]. N is the batch_size, and K is the class number.
label (Variable): The ground truth which is a 2-D tensor. Label is a
Tensor<int64> with shape [N x T], where T is the number of true
labels per example.
num_samples (int): The number for each example, num_samples should be
less than the number of class.
seed (int): The random seed for generating random number, which is used
in the process of sampling. Default is 0.
remove_accidental_hits (bool): A flag indicating whether to remove
accidental hits when sampling. If True and if a sample[i, j]
accidentally hits true labels, then the corresponding
sampled_logits[i, j] is minus by 1e20 to make its softmax result
close to zero. Default is True.
Returns:
Variable: Return the cross entropy loss which is a 2-D tensor with shape
[N x 1].
Examples:
.. code-block:: python
logits = fluid.layers.data(name='data', shape=[256], dtype='float32')
label = fluid.layers.data(name='label', shape=[5], dtype='int64')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.sampled_softmax_with_cross_entropy(
logits=fc, label=label, num_samples=25)
"""
helper = LayerHelper('sample_logits', **locals())
samples = helper.create_variable_for_type_inference(dtype='int64')
probabilities = helper.create_variable_for_type_inference(
dtype=logits.dtype)
sampled_logits \
= helper.create_variable_for_type_inference(dtype=logits.dtype)
sampled_label = helper.create_variable_for_type_inference(dtype='int64')
helper.append_op(
type='sample_logits',
inputs={
'Logits': logits,
'Label': label,
'CustomSamples': custom_samples,
'CustomProbabilities': custom_probabilities
},
outputs={
'Samples': samples,
'Probabilities': probabilities,
'SampledLabel': sampled_label,
'SampledLogits': sampled_logits
},
attrs={
'use_custom_samples': use_custom_samples,
'uniq': uniq,
'remove_accidental_hits': remove_accidental_hits,
'num_samples': num_samples,
'seed': seed
})
return sampled_logits, sampled_label, samples, probabilities
def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
"""
This layer computes the smooth L1 loss for Variable :attr:`x` and :attr:`y`.
......
......@@ -350,6 +350,7 @@ class OpTest(unittest.TestCase):
actual_t = np.array(actual)
expect = self.outputs[out_name]
expect_t = expect[0] if isinstance(expect, tuple) else expect
#import pdb; pdb.set_trace()
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
......
......@@ -374,6 +374,16 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_sample_logits(self):
program = Program()
with program_guard(program):
logits = layers.data(name='Logits', shape=[256], dtype='float64')
label = layers.data(name='Label', shape=[5], dtype='int64')
num_samples = 25
output = layers.sample_logits(logits, label, num_samples)
self.assertIsNotNone(output)
print(str(program))
@decorators.prog_scope()
def test_nce(self):
window_size = 5
......
......@@ -156,8 +156,26 @@ def append_input_output(block, op_proto, np_list, is_input, dtype):
return var_dict
def var_cast(block, input):
if input.dtype == core.VarDesc.VarType.FP32 or input.dtype == core.VarDesc.VarType.FP32:
return input
out = block.create_var(dtype="float32", shape=[1])
op = block.append_op(
inputs={"X": input},
outputs={"Out": out},
type='cast',
attrs={
'out_dtype': core.VarDesc.VarType.FP32,
'in_dtype': input.dtype
})
op.desc.infer_var_type(block.desc)
op.desc.infer_shape(block.desc)
return out
def append_loss_ops(block, output_names):
mean_inputs = list(map(block.var, output_names))
mean_inputs = [var_cast(block, x) for x in mean_inputs]
if len(mean_inputs) == 1:
loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册