From 58ad40cc15104757fc270d127e2be76a9e6bc999 Mon Sep 17 00:00:00 2001 From: xuezhong Date: Wed, 30 Jan 2019 14:04:44 +0000 Subject: [PATCH] add sample_logits op --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/math/CMakeLists.txt | 1 + paddle/fluid/operators/math/sample_prob.cc | 26 + paddle/fluid/operators/math/sample_prob.cu | 188 +++ paddle/fluid/operators/math/sample_prob.h | 118 ++ paddle/fluid/operators/sample_logits_op.cc | 248 ++++ paddle/fluid/operators/sample_logits_op.cu | 321 +++++ paddle/fluid/operators/sample_logits_op.h | 275 ++++ python/paddle/fluid/__init__.py | 2 +- python/paddle/fluid/layers/nn.py | 99 ++ .../paddle/fluid/tests/unittests/op_test.py | 1 + .../fluid/tests/unittests/test_layers.py | 10 + .../tests/unittests/test_sample_logits.py | 1233 +++++++++++++++++ .../paddle/fluid/tests/unittests/testsuite.py | 18 + 14 files changed, 2540 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/math/sample_prob.cc create mode 100644 paddle/fluid/operators/math/sample_prob.cu create mode 100644 paddle/fluid/operators/math/sample_prob.h create mode 100644 paddle/fluid/operators/sample_logits_op.cc create mode 100644 paddle/fluid/operators/sample_logits_op.cu create mode 100644 paddle/fluid/operators/sample_logits_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_sample_logits.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e099425b9..52e85789c 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -66,7 +66,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler tree2col) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search) if (WITH_GPU) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index e20524012..5c44d044c 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -39,6 +39,7 @@ math_library(cross_entropy) math_library(cos_sim_functor) math_library(depthwise_conv) math_library(im2col) +math_library(sample_prob) math_library(sampler) math_library(gru_compute DEPS activation_functions math_function) diff --git a/paddle/fluid/operators/math/sample_prob.cc b/paddle/fluid/operators/math/sample_prob.cc new file mode 100644 index 000000000..1a1751d01 --- /dev/null +++ b/paddle/fluid/operators/math/sample_prob.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/sample_prob.h" + +namespace paddle { +namespace operators { +namespace math { + +template class SampleWithProb; +template class SampleWithProb; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/sample_prob.cu b/paddle/fluid/operators/math/sample_prob.cu new file mode 100644 index 000000000..01c61fd80 --- /dev/null +++ b/paddle/fluid/operators/math/sample_prob.cu @@ -0,0 +1,188 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sample_prob.h" +#include "paddle/fluid/operators/math/sampler.h" + +namespace paddle { +namespace operators { +namespace math { + +using Tensor = framework::Tensor; + +template +__device__ T gpu_adjust_prob(const T prob, const int num_samples, + const int num_tries) { + if (num_samples == num_tries) { + return prob * num_samples; + } else { + return -expm1(num_tries * log1p(-prob)); + } +} + +class GPULogUniformSampler { + public: + __device__ int64_t Sample(float random, const int range, + const float log_range) const; + __device__ float Probability(int64_t value, const float log_range) const; +}; + +__device__ int64_t GPULogUniformSampler::Sample(float random, const int range, + const float log_range) const { + // Got Log Uniform distribution from uniform distribution by + // inverse_transform_sampling method + const int64_t value = static_cast(exp(random * log_range)) - 1; + // Mathematically, value should be <= range_, but might not be due to some + // floating point roundoff, so we mod by range_. + return value % range; +} + +__device__ float GPULogUniformSampler::Probability( + int64_t value, const float log_range) const { + // Given f(x) = 1/[(x+1) * log_range_] + // The value's probability is integral of f(x) from value to (value + 1) + return (log((value + 2.0) / (value + 1.0))) / log_range; +} + +template +__global__ void SamplingCondidate( + const size_t n, const int num_tries, const int range, const float log_range, + const int num_true, const std::size_t num_samples, + const int64_t* label_data, int64_t* samples_data, T* probabilities_data) { + const int num_sampled_classes = num_true + num_samples; + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = 0; + GPULogUniformSampler sampler; + + for (; idx < n; idx += blockDim.x * gridDim.x) { + int col_idx = idx % num_sampled_classes; + int row_idx = idx / num_sampled_classes; + if (col_idx < num_true) { + samples_data[idx] = label_data[row_idx * num_true + col_idx]; + } else { + samples_data[idx] = samples_data[col_idx]; + } + probabilities_data[idx] = sampler.Probability(samples_data[idx], log_range); + probabilities_data[idx] = + gpu_adjust_prob(probabilities_data[idx], num_samples, num_tries); + } +} + +template +int UniqSampler(const Sampler& sampler, const std::size_t num_samples, + int64_t* samples_data) { + // sample num_samles unique samples for an example, note that they are not + // all negative samples + std::unordered_set tmp_samples; + tmp_samples.clear(); + int num_tries = 0; + int j = 0; + while (j < num_samples) { + ++num_tries; + auto v = sampler.Sample(); + auto insert_ok = tmp_samples.insert(v).second; + if (!insert_ok) { + continue; + } + samples_data[j] = v; + ++j; + } + return num_tries; +} +/* +template +void Print(Tensor & t, std::string name) { + if (!FLAGS_debug_print) { + return; + } + VLOG(1) << "qxz print "<< name; + VLOG(1) << name << "size = " << t.numel(); + size_t size = t.numel(); + type *d = t.data(); +#ifdef PADDLE_WITH_CUDA + std::vector vec; + platform::DeviceContextPool::Instance().Get(t.place())->Wait(); + if (platform::is_gpu_place(t.place())) { + vec.resize(size); + cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost); + d = vec.data(); + } +#endif + VLOG(1) << name << " data_ptr = " << static_cast(d); + std::string out; + for (size_t i = 0; i < size; i++) { + out += std::to_string(d[i]); + out += ","; + } + VLOG(1) << out; +}*/ + +template +void GPUSampleWithProb::operator()( + const platform::CUDADeviceContext& context, const int seed, + const int dict_size, const bool uniq, const std::size_t num_samples, + const Tensor* L, Tensor* S, Tensor* P) { + // UNDERSTAND: dimension issues + const auto lbl_dim = L->dims(); + const int batch_size = lbl_dim[0]; + const int num_true = lbl_dim[1]; + const int num_sampled_classes = num_true + num_samples; + framework::DDim ret_dim{batch_size, num_sampled_classes}; + + // UNDERSTAND: raw data view + const int64_t* label_data = L->data(); + int64_t* samples_data = S->data(); + T* probabilities_data = P->data(); + + int s_size = num_samples; + framework::DDim s_dim{s_size}; + Tensor s; + int64_t* s_data = s.mutable_data(s_dim, platform::CPUPlace()); + + math::LogUniformSampler sampler(dict_size, seed); + + int range = dict_size; + float log_range = log(range + 1); + + int num_tries = UniqSampler(sampler, num_samples, s_data); + VLOG(1) << "num_tries: " << num_tries; + PADDLE_ENFORCE(cudaMemcpy(samples_data + num_true, s_data, + sizeof(int64_t) * num_samples, + cudaMemcpyHostToDevice)); + + int threads = 512; + const size_t size = batch_size * num_sampled_classes; + int grid = (batch_size * num_sampled_classes + threads - 1) / threads; + SamplingCondidate<<>>( + size, num_tries, range, log_range, num_true, num_samples, label_data, + samples_data, probabilities_data); +} + +template class GPUSampleWithProb; +template class GPUSampleWithProb; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/sample_prob.h b/paddle/fluid/operators/math/sample_prob.h new file mode 100644 index 000000000..58d21c63f --- /dev/null +++ b/paddle/fluid/operators/math/sample_prob.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/sampler.h" + +namespace paddle { +namespace operators { +namespace math { + +using Tensor = framework::Tensor; + +/* UNDERSTAND: utility function to adjust probability for unique sampling, +return whatever as it is if not using unique samping */ +template +static T adjust_prob(const T prob, const int num_samples, const int num_tries) { + if (num_samples == num_tries) { + return prob * num_samples; + } else { + return -expm1(num_tries * log1p(-prob)); + } +} + +template +class SampleWithProb { + public: + void operator()(const DeviceContext& context, const Sampler& sampler, + const std::size_t num_samples, const Tensor* L, Tensor* S, + Tensor* P) { + // UNDERSTAND: dimension issues + const auto lbl_dim = L->dims(); + const int batch_size = lbl_dim[0]; + const int num_true = lbl_dim[1]; + const int num_sampled_classes = num_true + num_samples; + framework::DDim ret_dim{batch_size, num_sampled_classes}; + + // UNDERSTAND: raw data view + const int64_t* label_data = L->data(); + int64_t* samples_data = + S->mutable_data(ret_dim, context.GetPlace()); + T* probabilities_data = P->mutable_data(ret_dim, context.GetPlace()); + + // temp sets for unique sampling + std::unordered_set tmp_samples; + int j = 0; // column index + // add true labels, not that efficient + while (j < num_true) { + for (int i = 0; i < batch_size; ++i) { + auto samples_index = i * num_sampled_classes + j; + auto v = label_data[i * num_true + j]; + samples_data[samples_index] = v; + probabilities_data[samples_index] = sampler.Probability(v); + } + ++j; + } + + // sample num_samles unique samples for an example, note that they are not + // all negative samples + tmp_samples.clear(); + int num_tries = 0; + while (j < num_sampled_classes) { + ++num_tries; + auto v = sampler.Sample(); + auto insert_ok = tmp_samples.insert(v).second; + if (!insert_ok) { + continue; + } + auto p = sampler.Probability(v); + for (int i = 0; i < batch_size; ++i) { + auto samples_index = i * num_sampled_classes + j; + samples_data[samples_index] = v; + probabilities_data[samples_index] = p; + } + ++j; + } + + // compute Q(y|x), because of unique sampling, probabilities need to be + // adjusted + for (int k = 0; k < num_sampled_classes; ++k) { + for (int i = 0; i < batch_size; ++i) { + auto samples_index = i * num_sampled_classes + k; + probabilities_data[samples_index] = adjust_prob( + probabilities_data[samples_index], num_samples, num_tries); + } + } + } +}; + +#ifdef PADDLE_WITH_CUDA +template +class GPUSampleWithProb { + public: + void operator()(const platform::CUDADeviceContext& context, const int seed, + const int dict_size, const bool uniq, + const std::size_t num_samples, const Tensor* L, Tensor* S, + Tensor* P); +}; +#endif +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/sample_logits_op.cc b/paddle/fluid/operators/sample_logits_op.cc new file mode 100644 index 000000000..160eb066e --- /dev/null +++ b/paddle/fluid/operators/sample_logits_op.cc @@ -0,0 +1,248 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sample_logits_op.h" +#include "paddle/fluid/operators/math/sample_prob.h" + +namespace paddle { +namespace operators { + +class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Logits", + "(Tensor, default: Tensor), The unscaled log probabilities " + "which is a 2-D tensor with shape [N x K]. N is the batch_size, " + "and K is the class number."); + AddInput("Label", + "(Tensor) The ground truth which is a 2-D tensor. Label is a " + "Tensor with shape [N x NT], where NT is the number of" + "true labels for each example."); + AddInput( + "CustomSamples", + "(Tensor, default: Tensor), A 2-D tensor with shaoe [N x " + "S+NT]." + "The customized sample labels with true labels at first. This tensor" + "is only use_custom_samples is true.") + .AsDispensable(); + AddInput( + "CustomProbabilities", + "(Tensor, default: Tensor), A 2-D tensor with shaoe [N x S+NT]." + "The customized sample probabilities with true labels at first. This " + "tensor is only use_custom_samples is true.") + .AsDispensable(); + AddOutput( + "Samples", + "(Tensor, default: Tensor), A 2-D tensor with shape [N x " + "S+NT]." + "The outputs value of sampler by given the true label, where S is the " + "number of negative sample for each example. So Samples includes NT " + "true" + "labels and S negative labels for each example. This will be used in" + "backward calculation.") + .AsIntermediate(); + AddOutput( + "Probabilities", + "(Tensor, default: Tensor), A 2-D tensor with shape [N x " + "S+NT]." + "The outputs value of progabilites of samples by given the true label, " + "where S is the " + "number of negative sample for each example. So Samples includes NT " + "true" + "labels and S negative labels for each example.") + .AsIntermediate(); + AddOutput("SampledLogits", + "(Tensor, default: Tensor), A 2-D tensor with shape" + "[N x S+NT]. The outputs value of sampled softmax, which will be" + "used in backward calculation.") + .AsIntermediate(); + AddOutput("SampledLabel", + "(Tensor, default: Tensor), A 2-D tensor. The cross " + "entropy loss with shape [N x NT]."); + AddAttr( + "use_custom_samples", + "An indicator whether to use custom samples with probabilities, if True" + "the operator will use custom samples and custom probabilities" + "otherwise, the operator will generate them by itself.") + .SetDefault(false); + AddAttr( + "uniq", + "An indicator whether to sample non-repetitive negtive labels, if True" + "the operator will sample negtive labels without replacement." + "otherwise, the operator will sample negtive labels with replacement.") + .SetDefault(false); + AddAttr( + "remove_accidental_hits", + "An indicator whether to remove accidental hits when samples hits true" + "labels, the removal is implemented by subtracting the corresponding" + "logits by float_max to subpress their softmax to be zero.") + .SetDefault(true); + AddAttr("num_samples", "The number of negative samples."); + AddAttr("seed", "Random seed for generating samples").SetDefault(0); + + AddComment(R"DOC( +TODO(chenfeiyu): Write documentation for this Operator. +Sampled Softmax With Cross Entropy Operator. + +Cross entropy loss with sampled softmax is used as the output layer extensively. +This operator computes the softmax normalized values for each row of the input +tensor, after which cross-entropy loss is computed. This provides a more +numerically stable gradient. + +Because this operator performs a softmax on logits internally, it expects +unscaled logits. This operator should not be used with the output of +softmax operator since that would produce incorrect results. + +When the attribute soft_label is set false, this operators expects mutually +exclusive hard labels, each sample in a batch is in exactly one class with a +probability of 1.0. Each sample in the batch will have a single label. + +The equation is as follows: + +1) Hard label (one-hot label, so every sample has exactly one class) + +$$Loss_j = -\text{Logit}_{Label_j} + +\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), +j = 1,..., K$$ + +2) Soft label (each sample can have a distribution over all classes) + +$$Loss_j = -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i - +\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), +j = 1,...,K$$ + +)DOC"); + } +}; + +class SampleLogitsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Logits"), + "Input(Logits) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + + PADDLE_ENFORCE(ctx->HasOutput("Samples"), + "Output(Samples) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Probabilities"), + "Output(Probabilities) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("SampledLogits"), + "Output(SampledLogits) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("SampledLabel"), + "Output(SampledLabel) should be not null."); + + auto logits_dims = ctx->GetInputDim("Logits"); + auto labels_dims = ctx->GetInputDim("Label"); + + PADDLE_ENFORCE_EQ( + logits_dims.size(), 2UL, + "The logits of softmax_with_cross_entropy should be a 2-D tensor."); + PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL, + "The labels should be a 2-D tensor."); + + const int num_samples = ctx->Attrs().Get("num_samples"); + const int num_sampled_classes = labels_dims[1] + num_samples; + ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes}); + ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes}); + ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes}); + ctx->SetOutputDim("SampledLabel", {logits_dims[0], labels_dims[1]}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits")); + framework::OpKernelType kt = + framework::OpKernelType(data_type, ctx.device_context()); + // kt.place_ = platform::CPUPlace(); + return kt; + } +}; + +// UNDERSTAND: InferShape for Grad +class SampleLogitsOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Logits"), + "Input(Logits) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Samples"), + "Input(Samples) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("SampledLogits"), + "Input(SampledLogits) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("SampledLogits")), + "Input(SampledLogits@Grad) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")), + "Output(Logits@Grad) should be not null."); + + auto logit_dims = ctx->GetInputDim("Logits"); + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, + "The label should be a 2-D tensor."); + PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL, + "The logits should be a 2-D tensor."); + + ctx->SetOutputDim(framework::GradVarName("Logits"), + ctx->GetInputDim("Logits")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar( + ctx.InputVar(framework::GradVarName("SampledLogits"))); + framework::OpKernelType kt = + framework::OpKernelType(data_type, ctx.device_context()); + // kt.place_ = platform::CPUPlace(); + return kt; + } +}; + +// UNDERSTAND: what's the rule for making a GradMaker TODO +class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* grad_op = new framework::OpDesc(); + grad_op->SetType("sample_logits_grad"); + grad_op->SetInput("Logits", Input("Logits")); + grad_op->SetInput("Label", Input("Label")); + grad_op->SetInput("Samples", Output("Samples")); + grad_op->SetInput("SampledLogits", Output("SampledLogits")); + grad_op->SetInput(framework::GradVarName("SampledLogits"), + OutputGrad("SampledLogits")); + grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits")); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(sample_logits, ops::SampleLogitsOp, ops::SampleLogitsOpMaker, + ops::SampleLogitsGradMaker); +REGISTER_OPERATOR(sample_logits_grad, ops::SampleLogitsOpGrad); +REGISTER_OP_CPU_KERNEL(sample_logits, ops::SampleLogitsKernel, + ops::SampleLogitsKernel); +REGISTER_OP_CPU_KERNEL(sample_logits_grad, ops::SampleLogitsGradKernel, + ops::SampleLogitsGradKernel); diff --git a/paddle/fluid/operators/sample_logits_op.cu b/paddle/fluid/operators/sample_logits_op.cu new file mode 100644 index 000000000..5b311bb67 --- /dev/null +++ b/paddle/fluid/operators/sample_logits_op.cu @@ -0,0 +1,321 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sample_prob.h" +#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/fluid/operators/sample_logits_op.h" + +namespace paddle { +namespace operators { + +DEFINE_bool(debug_print, true, "run debug mode"); + +// UNDERSTAND: something like take_along_axis in numpy. +template +__global__ void GPUTakeAlongD1(size_t size, const int batch_size, + const int array_slice_size, + const int idx_slice_size, const T* p_array, + const int64_t* p_index, T* p_value) { + const auto value_slice_size = idx_slice_size; + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = blockDim.x * gridDim.x; + + for (; idx < size; idx += step_size) { + int i = idx / idx_slice_size; + auto array_index = p_index[idx]; + p_value[idx] = p_array[i * array_slice_size + array_index]; + } +} + +// UNDERSTAND: something like put_along_axis in numpy but if there is duplicate +// indices, scatter is done in += way. +template +__global__ void GPUPutAlongD1(size_t size, const int batch_size, + const int array_slice_size, + const int idx_slice_size, T* p_array, + const int64_t* p_index, const T* p_value) { + const auto value_slice_size = idx_slice_size; + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = blockDim.x * gridDim.x; + + // size == batch_size + for (; idx < size; idx += step_size) { + int i = idx; + for (int j = 0; j < idx_slice_size; ++j) { + auto array_index = p_index[i * idx_slice_size + j]; + p_array[i * array_slice_size + array_index] += + p_value[i * idx_slice_size + j]; + } + } +} + +// UNDERSTAND: set label as 0,1,...,num_true-1 +template +__global__ void GPUSetLabel(size_t size, const int num_true, int64_t* p_array) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = blockDim.x * gridDim.x; + + for (; idx < size; idx += step_size) { + p_array[idx] = idx % num_true; + } +} + +// UNDERSTAND: compute accidentdal hits from samples and minus corresponding +// logits by a float max, here 1e20 +template +__global__ void gpu_compute_remove_accidental_hits(const int size, + const int num_true, + const int idx_slice_size, + const int64_t* p_index, + T* p_value) { + const auto value_slice_size = idx_slice_size; + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int step_size = blockDim.x * gridDim.x; + + for (; idx < size; idx += step_size) { + int i = idx / idx_slice_size; + if (idx % idx_slice_size < num_true) continue; + for (int j = 0; j < num_true; ++j) { + const auto true_idx = i * idx_slice_size + j; + if (p_index[true_idx] == p_index[idx]) { + p_value[idx] -= 1e20; + break; + } + } + } +} + +template +class SampleLogitsCUDAKernel : public framework::OpKernel { + public: + using Tensor = framework::Tensor; + template + void Print(const Tensor& t, std::string name) const { + if (!FLAGS_debug_print) { + return; + } + VLOG(1) << "qxz print " << name; + VLOG(1) << name << "size = " << t.numel(); + size_t size = t.numel(); + type* d = t.data(); +#ifdef PADDLE_WITH_CUDA + std::vector vec; + platform::DeviceContextPool::Instance().Get(t.place())->Wait(); + if (platform::is_gpu_place(t.place())) { + vec.resize(size); + cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost); + d = vec.data(); + } +#endif + VLOG(1) << name << " data_ptr = " << static_cast(d); + std::string out; + for (size_t i = 0; i < size; i++) { + out += std::to_string(d[i]); + out += ","; + } + VLOG(1) << out; + } + + void Compute(const framework::ExecutionContext& context) const override { + // get necessary inputs + const Tensor* logits = context.Input("Logits"); + const Tensor* label = context.Input("Label"); + VLOG(3) << "Enter SampleLogitsCUDAKernel"; + + // get necessary outputs + Tensor* samples = context.Output("Samples"); + Tensor* probabilities = context.Output("Probabilities"); + Tensor* sampled_logits = context.Output("SampledLogits"); + Tensor* sampled_label = context.Output("SampledLabel"); + + // shapes + const auto batch_size = logits->dims()[0]; + const auto num_classes = logits->dims()[1]; + const auto label_dim = label->dims(); + const auto num_true = label_dim[1]; + const auto samples_dim = samples->dims(); + + // attrs + const auto num_samples = context.Attr("num_samples"); + const bool use_custom_samples = context.Attr("use_custom_samples"); + const bool uniq = context.Attr("uniq"); + const bool remove_accidental_hits = + context.Attr("remove_accidental_hits"); + + // device contexts + auto& dev_ctx = context.cuda_device_context(); + + // UNDERSTAND: allocate memories for temporaries + sampled_logits->mutable_data(samples_dim, context.GetPlace()); + math::SetConstant set_zero; + set_zero(dev_ctx, sampled_logits, static_cast(0)); + + auto sampled_label_data = + sampled_label->mutable_data(label_dim, context.GetPlace()); + int threads = 512; + size_t size = batch_size * num_true; + int grid = (size + threads - 1) / threads; + GPUSetLabel< + T><<>>( + size, num_true, sampled_label_data); + + if (use_custom_samples) { + const Tensor* custom_samples = context.Input("CustomSamples"); + const Tensor* custom_probabilities = + context.Input("CustomProbabilities"); + samples->ShareDataWith(*custom_samples); + probabilities->ShareDataWith(*custom_probabilities); + } else { + samples->mutable_data(context.GetPlace()); + probabilities->mutable_data(samples_dim, context.GetPlace()); + // UNDERSTAND: sampling + const auto seed = context.Attr("seed"); + auto sampler_with_prob = math::GPUSampleWithProb(); + Print(*samples, std::string("samples1")); + sampler_with_prob(context.cuda_device_context(), seed, num_classes, uniq, + num_samples, label, samples, probabilities); + } + Print(*samples, std::string("samples2")); + Print(*probabilities, std::string("probabilities")); + + // UNDERSTAND: gather sampled logits and remove accidental hits if needed + const auto num_take = samples->dims()[1]; + const auto array_dims = logits->dims(); + const auto idx_dims = samples->dims(); + + const T* p_array = logits->data(); + const int64_t* p_index = samples->data(); + T* p_value = sampled_logits->data(); + + // src slice size + const auto array_slice_size = array_dims[1]; + // index slice size + const auto idx_slice_size = idx_dims[1]; + + size = batch_size * num_take; + grid = (size + threads - 1) / threads; + GPUTakeAlongD1< + T><<>>( + size, batch_size, array_slice_size, idx_slice_size, p_array, p_index, + p_value); + Print(*sampled_logits, std::string("sampled_logits")); + + if (remove_accidental_hits) { + const size_t size = batch_size * (num_true + num_samples); + int grid = (size + threads - 1) / threads; + gpu_compute_remove_accidental_hits< + T><<>>( + size, num_true, idx_slice_size, p_index, p_value); + Print(*sampled_logits, + std::string("sampled_logits_remove_accidental_hits")); + } + + // subtracted sampled logits with logQ(y|x) + auto probs = EigenMatrix::From(*probabilities); + auto smp_logits = EigenMatrix::From(*sampled_logits); + smp_logits.device(*dev_ctx.eigen_device()) = + (smp_logits - probs.log().unaryExpr(TolerableValue())) + .unaryExpr(TolerableValue()); + Print(*sampled_logits, std::string("sampled_logits_res")); + } +}; + +template +class SampleLogitsGradCUDAKernel : public framework::OpKernel { + public: + using Tensor = framework::Tensor; + template + void Print(const Tensor& t, std::string name) const { + if (!FLAGS_debug_print) { + return; + } + VLOG(1) << "qxz print " << name; + VLOG(1) << name << "size = " << t.numel(); + size_t size = t.numel(); + const type* d = t.data(); +#ifdef PADDLE_WITH_CUDA + std::vector vec; + platform::DeviceContextPool::Instance().Get(t.place())->Wait(); + if (platform::is_gpu_place(t.place())) { + vec.resize(size); + cudaMemcpy(vec.data(), d, sizeof(T) * size, cudaMemcpyDeviceToHost); + d = vec.data(); + } +#endif + VLOG(1) << name << " data_ptr = " << static_cast(d); + std::string out; + for (size_t i = 0; i < size; i++) { + out += std::to_string(d[i]); + out += ","; + } + VLOG(1) << out; + } + + void Compute(const framework::ExecutionContext& context) const override { + auto logits_grad = context.Output(framework::GradVarName("Logits")); + const Tensor* samples = context.Input("Samples"); + const Tensor* sampled_logits_grad = + context.Input(framework::GradVarName("SampledLogits")); + logits_grad->mutable_data(context.GetPlace()); + + auto& dev_ctx = context.cuda_device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, logits_grad, static_cast(0)); + + // UNDERSTAND: scatter it back to logit_grad + const auto batch_size = samples->dims()[0]; + const auto num_put = samples->dims()[1]; + const auto array_dims = logits_grad->dims(); + const auto idx_dims = samples->dims(); + + T* p_array = logits_grad->data(); + const int64_t* p_index = samples->data(); + const T* p_value = sampled_logits_grad->data(); + + // src slice size + const auto array_slice_size = array_dims[1]; + // index slice size + const auto idx_slice_size = idx_dims[1]; + + int threads = 128; + const size_t size = batch_size; + int grid = (size + threads - 1) / threads; + + Print(*sampled_logits_grad, std::string("sampled_logits_grad")); + Print(*samples, std::string("samples")); + GPUPutAlongD1< + T><<>>( + size, batch_size, array_slice_size, idx_slice_size, p_array, p_index, + p_value); + Print(*logits_grad, std::string("logits_grad")); + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL(sample_logits, ops::SampleLogitsCUDAKernel, + ops::SampleLogitsCUDAKernel); +REGISTER_OP_CUDA_KERNEL(sample_logits_grad, + ops::SampleLogitsGradCUDAKernel, + ops::SampleLogitsGradCUDAKernel); diff --git a/paddle/fluid/operators/sample_logits_op.h b/paddle/fluid/operators/sample_logits_op.h new file mode 100644 index 000000000..77d66a642 --- /dev/null +++ b/paddle/fluid/operators/sample_logits_op.h @@ -0,0 +1,275 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/sample_prob.h" +#include "paddle/fluid/operators/math/softmax.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +struct TolerableValue { + HOSTDEVICE T operator()(const T& x) const { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + if (x == INFINITY) return kApproInf; + if (x == -INFINITY) return -kApproInf; + return x; + } +}; + +// UNDERSTAND: something like take_along_axis in numpy. +template +static void CPUTakeAlongD1(const platform::DeviceContext& ctx, + const framework::Tensor& array, + const framework::Tensor& index, + framework::Tensor* value) { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace())); + // UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K) + PADDLE_ENFORCE(index.dims().size() == 2 && array.dims().size() == 2 && + index.dims()[0] == array.dims()[0] && + index.dims() == value->dims()); + + const auto batch_size = index.dims()[0]; + const auto num_take = index.dims()[1]; + const auto array_dims = array.dims(); + const auto idx_dims = index.dims(); + + // UNDERSTAND: no allocations here + const T* p_array = array.data(); + const int64_t* p_index = index.data(); + T* p_value = value->data(); + + // src slice size + const auto array_slice_size = array_dims[1]; + + // index slice size + const auto idx_slice_size = idx_dims[1]; + const auto value_slice_size = idx_slice_size; + + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < num_take; ++j) { + auto array_index = p_index[i * idx_slice_size + j]; + p_value[i * value_slice_size + j] = + p_array[i * array_slice_size + array_index]; + } + } +} + +// UNDERSTAND: something like put_along_axis in numpy but if there is duplicate +// indices, scatter is done in += way. +template +static void CPUPutAlongD1(const platform::DeviceContext& ctx, + framework::Tensor* array, + const framework::Tensor& index, + const framework::Tensor& value) { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace())); + // UNDERSTAND: check shape src(B, C), index(B, K), out should also be (B, K) + PADDLE_ENFORCE(index.dims().size() == 2 && array->dims().size() == 2 && + index.dims()[0] == array->dims()[0] && + index.dims() == value.dims()); + const auto batch_size = index.dims()[0]; + const auto num_put = index.dims()[1]; + auto array_dims = array->dims(); + auto idx_dims = index.dims(); + + // UNDERSTAND: no allocations here + T* p_array = array->data(); + const int64_t* p_index = index.data(); + const T* p_value = value.data(); + + // slice sizes + const auto array_slice_size = array_dims[1]; + const auto idx_slice_size = idx_dims[1]; + const auto value_slice_size = idx_slice_size; + + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < num_put; ++j) { + auto array_index = p_index[i * idx_slice_size + j]; + p_array[i * array_slice_size + array_index] += + p_value[i * value_slice_size + j]; + } + } +} + +// UNDERSTAND: compute accidentdal hits from samples and minus corresponding +// logits by a float max, here 1e20 +template +static void compute_remove_accidental_hits(const platform::DeviceContext& ctx, + framework::Tensor* sampled_logits, + const framework::Tensor& samples, + const int num_true) { + const auto batch_size = sampled_logits->dims()[0]; + const auto num_sampled_classes = sampled_logits->dims()[1]; + T* sampled_logits_data = sampled_logits->data(); + const auto samples_data = samples.data(); + + std::unordered_set tmp_true_labels; + for (int i = 0; i < batch_size; ++i) { + tmp_true_labels.clear(); + tmp_true_labels.insert(samples_data + i * num_sampled_classes, + samples_data + i * num_sampled_classes + num_true); + for (int j = num_true; j < num_sampled_classes; ++j) { + const auto idx = i * num_sampled_classes + j; + if (tmp_true_labels.find(samples_data[idx]) != tmp_true_labels.end()) + sampled_logits_data[idx] -= 1e20; + } + } +} + +template +class SampleLogitsKernel : public framework::OpKernel { + public: + using Tensor = framework::Tensor; + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()), + "This kernel only runs on CPU."); + VLOG(3) << "Enter SampleLogitsKernel"; + // get necessary inputs + const Tensor* logits = context.Input("Logits"); + const Tensor* label = context.Input("Label"); + + // get necessary outputs + Tensor* samples = context.Output("Samples"); + Tensor* probabilities = context.Output("Probabilities"); + Tensor* sampled_logits = context.Output("SampledLogits"); + Tensor* sampled_label = context.Output("SampledLabel"); + + // shapes + const auto batch_size = logits->dims()[0]; + const auto num_classes = logits->dims()[1]; + const auto label_dim = label->dims(); + const auto num_true = label_dim[1]; + const auto samples_dim = samples->dims(); + + // attrs + const auto num_samples = context.Attr("num_samples"); + const bool use_custom_samples = context.Attr("use_custom_samples"); + const bool remove_accidental_hits = + context.Attr("remove_accidental_hits"); + + // device contexts + auto& dev_ctx = + context.template device_context(); + + // UNDERSTAND: allocate memories for temporaries + sampled_logits->mutable_data(samples_dim, context.GetPlace()); + auto sampled_label_data = + sampled_label->mutable_data(label_dim, context.GetPlace()); + for (int i = 0; i < batch_size; ++i) + for (int j = 0; j < num_true; ++j) + sampled_label_data[i * num_true + j] = j; + + if (use_custom_samples) { + const Tensor* custom_samples = context.Input("CustomSamples"); + const Tensor* custom_probabilities = + context.Input("CustomProbabilities"); + samples->ShareDataWith(*custom_samples); + probabilities->ShareDataWith(*custom_probabilities); + } else { + samples->mutable_data(context.GetPlace()); + probabilities->mutable_data(samples_dim, context.GetPlace()); + // UNDERSTAND: sampling + const auto seed = context.Attr("seed"); + auto sampler_with_prob = + math::SampleWithProb(); + sampler_with_prob(dev_ctx, math::LogUniformSampler(num_classes, seed), + num_samples, label, samples, probabilities); + } + + // UNDERSTAND: gather sampled logits and remove accidental hits if needed + CPUTakeAlongD1(dev_ctx, *logits, *samples, sampled_logits); + if (remove_accidental_hits) { + compute_remove_accidental_hits(dev_ctx, sampled_logits, *samples, + num_true); + } + + /* Debug + const auto num_sampled_classes = samples_dim[1]; + std::cout << "Sampled Logits" << std::endl; + const auto sampled_logits_data = sampled_logits->data(); + for (int i = 0; i < sampled_logits->numel(); ++i) { + std::cout << sampled_logits_data[i] << ", "; + if ((i + 1) % num_sampled_classes == 0) + std::cout << std::endl; + } + std::cout << std::endl; + */ + /* Debug + std::cout << "Samples" << std::endl; + const auto samples_data = samples->data(); + for (int i = 0; i < samples->numel(); ++i) { + std::cout << samples_data[i] << ", "; + if ((i + 1) % num_sampled_classes == 0) + std::cout << std::endl; + } + std::cout << std::endl; + */ + /* Debug + std::cout << "Probabilities" << std::endl; + const auto probabilities_data = probabilities->data(); + for (int i = 0; i < probabilities->numel(); ++i) { + std::cout << probabilities_data[i] << ", "; + if ((i + 1) % num_sampled_classes == 0) + std::cout << std::endl; + } + std::cout << std::endl; + */ + // subtracted sampled logits with logQ(y|x) + auto probs = EigenMatrix::From(*probabilities); + auto smp_logits = EigenMatrix::From(*sampled_logits); + smp_logits.device(*dev_ctx.eigen_device()) = + (smp_logits - probs.log().unaryExpr(TolerableValue())) + .unaryExpr(TolerableValue()); + } +}; + +template +class SampleLogitsGradKernel : public framework::OpKernel { + public: + using Tensor = framework::Tensor; + void Compute(const framework::ExecutionContext& context) const override { + auto logits_grad = context.Output(framework::GradVarName("Logits")); + const Tensor* samples = context.Input("Samples"); + const Tensor* sampled_logits_grad = + context.Input(framework::GradVarName("SampledLogits")); + logits_grad->mutable_data(context.GetPlace()); + + auto& dev_ctx = + context.template device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, logits_grad, static_cast(0)); + + // const bool remove_accidental_hits = + // context.Attr("remove_accidental_hits"); + + // UNDERSTAND: scatter it back to logit_grad + CPUPutAlongD1(dev_ctx, logits_grad, *samples, *sampled_logits_grad); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 564882bd2..896d98c97 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -131,7 +131,7 @@ def __bootstrap__(): 'eager_delete_tensor_gb', 'fast_eager_deletion_mode', 'allocator_strategy', 'reader_queue_speed_test_mode', 'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir', - 'inner_op_parallelism', 'enable_parallel_graph' + 'inner_op_parallelism', 'enable_parallel_graph', 'debug_print' ] if 'Darwin' not in sysstr: read_env_flags.append('use_pinned_memory') diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0e4b5aadc..8b033aa6b 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -87,6 +87,7 @@ __all__ = [ 'transpose', 'im2sequence', 'nce', + 'sample_logits', 'hsigmoid', 'beam_search', 'row_conv', @@ -5764,6 +5765,104 @@ def softmax_with_cross_entropy(logits, return loss +def sample_logits(logits, + label, + num_samples, + uniq=True, + remove_accidental_hits=True, + use_custom_samples=False, + custom_samples=None, + custom_probabilities=None, + seed=0): + """ + **Sampled Softmax With Cross Entropy Operator.** + + Cross entropy loss with sampled softmax is used as the output layer for + larger output classes extensively. This operator samples a number of samples + for each example(row), and computes the softmax normalized values for each + row of the sampled tensor, after which cross-entropy loss is computed. + This provides a more numerically stable gradient. + + Because this operator performs a softmax on logits internally, it expects + unscaled logits. This operator should not be used with the output of + softmax operator since that would produce incorrect results. + + For examples with T true labels (T >= 1), we assume that each true label has + a probability of 1/T. For each sample, S samples are generated using a + log uniform distribution. True labels are concatenated with hese samples to + form T + S samples for each example. So, assume the shape of logits is + [N x K], the shape for samples is [N x (T+S)]. For each sampled label, a + probability is calculated, which corresponds to the Q(y|x) in + [Jean et al., 2014](http://arxiv.org/abs/1412.2007). + + Logits are sampled according to the sampled labels. Then if + remove_accidental_hits is True, if a sample[i, j] accidentally hits true + labels, then the corresponding sampled_logits[i, j] is minus by 1e20 to + make its softmax result close to zero. Then samled logits are subtracted by + logQ(y|x), these sampled logits and re-indexed labels are used to compute + a softmax with cross entropy. + + Args: + logits (Variable): The unscaled log probabilities, which is a 2-D tensor + with shape [N x K]. N is the batch_size, and K is the class number. + label (Variable): The ground truth which is a 2-D tensor. Label is a + Tensor with shape [N x T], where T is the number of true + labels per example. + num_samples (int): The number for each example, num_samples should be + less than the number of class. + seed (int): The random seed for generating random number, which is used + in the process of sampling. Default is 0. + remove_accidental_hits (bool): A flag indicating whether to remove + accidental hits when sampling. If True and if a sample[i, j] + accidentally hits true labels, then the corresponding + sampled_logits[i, j] is minus by 1e20 to make its softmax result + close to zero. Default is True. + + Returns: + Variable: Return the cross entropy loss which is a 2-D tensor with shape + [N x 1]. + + Examples: + .. code-block:: python + + logits = fluid.layers.data(name='data', shape=[256], dtype='float32') + label = fluid.layers.data(name='label', shape=[5], dtype='int64') + fc = fluid.layers.fc(input=data, size=100) + out = fluid.layers.sampled_softmax_with_cross_entropy( + logits=fc, label=label, num_samples=25) + """ + helper = LayerHelper('sample_logits', **locals()) + samples = helper.create_variable_for_type_inference(dtype='int64') + probabilities = helper.create_variable_for_type_inference( + dtype=logits.dtype) + sampled_logits \ + = helper.create_variable_for_type_inference(dtype=logits.dtype) + sampled_label = helper.create_variable_for_type_inference(dtype='int64') + + helper.append_op( + type='sample_logits', + inputs={ + 'Logits': logits, + 'Label': label, + 'CustomSamples': custom_samples, + 'CustomProbabilities': custom_probabilities + }, + outputs={ + 'Samples': samples, + 'Probabilities': probabilities, + 'SampledLabel': sampled_label, + 'SampledLogits': sampled_logits + }, + attrs={ + 'use_custom_samples': use_custom_samples, + 'uniq': uniq, + 'remove_accidental_hits': remove_accidental_hits, + 'num_samples': num_samples, + 'seed': seed + }) + return sampled_logits, sampled_label, samples, probabilities + + def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): """ This layer computes the smooth L1 loss for Variable :attr:`x` and :attr:`y`. diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 0fe836683..2d15768c0 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -350,6 +350,7 @@ class OpTest(unittest.TestCase): actual_t = np.array(actual) expect = self.outputs[out_name] expect_t = expect[0] if isinstance(expect, tuple) else expect + #import pdb; pdb.set_trace() self.assertTrue( np.allclose( actual_t, expect_t, atol=atol, equal_nan=equal_nan), diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index e7bc1601a..7f7a51d9d 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -374,6 +374,16 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(output) print(str(program)) + def test_sample_logits(self): + program = Program() + with program_guard(program): + logits = layers.data(name='Logits', shape=[256], dtype='float64') + label = layers.data(name='Label', shape=[5], dtype='int64') + num_samples = 25 + output = layers.sample_logits(logits, label, num_samples) + self.assertIsNotNone(output) + print(str(program)) + @decorators.prog_scope() def test_nce(self): window_size = 5 diff --git a/python/paddle/fluid/tests/unittests/test_sample_logits.py b/python/paddle/fluid/tests/unittests/test_sample_logits.py new file mode 100644 index 000000000..b36694f11 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sample_logits.py @@ -0,0 +1,1233 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class Sampler(object): + def __init__(self, range, seed): + self.range_ = range + self.seed_ = seed + np.random.seed(self.seed_) + + def sample(self): + rasie("No Implementation!") + + def probability(self, value): + raise ("No Implementation!") + + +class LogUniformSampler(Sampler): + def __init__(self, range, seed): + super(LogUniformSampler, self).__init__(range, seed) + self.log_range_ = np.log(self.range_ + 1) + + def sample(self): + value = int(np.exp(np.random.uniform(0.0, self.log_range_)) - 1) + return value % self.range_ + + def probability(self, value): + return np.log((value + 2.0) / (value + 1.0)) / self.log_range_ + + +def adjust_prob(prob, num_samples, num_tries): + if num_samples == num_tries: + return prob * num_samples + else: + return -np.expm1(num_tries * np.log1p(-prob)) + + +def take_along_axis1(array, index): + out = np.zeros_like(index, dtype=array.dtype) + n_row, n_col = index.shape + for i in range(n_row): + for j in range(n_col): + out[i, j] = array[i, index[i, j]] + return out + + +def sample_prob(sampler, num_samples, label): + batch_size, num_true = label.shape + num_sampled_classes = num_samples + num_true + + samples = np.zeros((batch_size, num_sampled_classes), dtype=np.int64) + probabilities = np.zeros( + (batch_size, num_sampled_classes), dtype=np.float64) + + tmp_samples = set() + num_tries = 0 + j = 0 + while j < num_true: + for i in range(batch_size): + samples[i, j] = label[i, j] + probabilities[i, j] = sampler.probability(label[i, j]) + j += 1 + while j < num_sampled_classes: + v = sampler.sample() + num_tries += 1 + if v not in tmp_samples: + tmp_samples.add(v) + for i in range(batch_size): + samples[i, j] = v + probabilities[i, j] = sampler.probability(v) + j += 1 + for k in range(num_sampled_classes): + for i in range(batch_size): + probabilities[i, k] = adjust_prob(probabilities[i, k], num_samples, + num_tries) + return (samples, probabilities) + + +def compute_remove_accidental_hits(sampled_logits, samples, num_true): + batch_size, num_sampled_classes = samples.shape + for i in range(batch_size): + true_labels = set(samples[i, np.arange(num_true)]) + for j in range(num_true, num_sampled_classes): + if samples[i, j] in true_labels: + sampled_logits[i, j] -= 1e20 + + +def sample_logits(logits, + label, + num_samples, + seed, + remove_accidental_hits, + use_custom_samples, + custom_samples=None, + custom_probabilities=None): + batch_size, num_classes = logits.shape + num_true = label.shape[1] + num_sampled_classes = num_true + num_samples + + if use_custom_samples: + samples = custom_samples + probabilities = custom_probabilities + else: + sampler = LogUniformSampler(num_classes, seed) + samples, probabilities = sample_prob(sampler, num_samples, label) + sampled_logits = take_along_axis1(logits, samples) + + #print(samples) + #print(probabilities) + #print(sampled_logits) + if remove_accidental_hits: + compute_remove_accidental_hits(sampled_logits, samples, num_true) + sampled_logits -= np.log(probabilities) + sampled_label = np.tile(np.arange(num_true), (batch_size, 1)) + return (sampled_logits, samples, sampled_label, probabilities) + + +class TestSampleLogitsOp(OpTest): + ''' + Test SampleLogitsOp, but with random results precomputed + in python and just test the non-random part. + ''' + + def generate_data(self, logits, label, num_samples, seed, + remove_accidental_hits, use_custom_samples, + custom_samples, custom_probabilities): + self.attrs = { + 'num_samples': num_samples, + 'use_custom_samples': use_custom_samples, + 'remove_accidental_hits': remove_accidental_hits, + 'seed': seed + } + self.inputs = { + 'Logits': logits, + 'Label': label, + 'CustomSamples': custom_samples, + 'CustomProbabilities': custom_probabilities + } + + def set_data(self, batch_size, num_classes, num_true, num_samples, seed, + remove_accidental_hits): + logits = np.random.randn(batch_size, num_classes) + label = np.stack([ + np.random.choice( + range(0, num_classes), num_true, replace=False) + for _ in range(batch_size) + ]) + sampler = LogUniformSampler(num_classes, seed) + custom_samples, custom_probabilities = \ + sample_prob(sampler, num_samples, label) + use_custom_samples = True + remove_accidental_hits = remove_accidental_hits + self.generate_data(logits, label, num_samples, seed, + remove_accidental_hits, use_custom_samples, + custom_samples, custom_probabilities) + + def compute(self): + out = sample_logits(self.inputs["Logits"], self.inputs["Label"], + self.attrs["num_samples"], self.attrs["seed"], + self.attrs["remove_accidental_hits"], + self.attrs["use_custom_samples"], + self.inputs["CustomSamples"], + self.inputs["CustomProbabilities"]) + + self.outputs = { + 'SampledLogits': out[0], + 'Samples': out[1], + 'SampledLabel': out[2], + 'Probabilities': out[3] + } + + def setUp(self): + self.op_type = 'sample_logits' + batch_size = 5 + num_classes = 20 + num_true = 5 + num_samples = 10 + seed = 10 + remove_accidental_hits = True + self.set_data(batch_size, num_classes, num_true, num_samples, seed, + remove_accidental_hits) + self.compute() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + pass + self.check_grad( + ["Logits"], ["SampledLogits", "Samples"], max_relative_error=0.02) + + +class TestSampleLogitsOp2(TestSampleLogitsOp): + def setUp(self): + self.op_type = 'sample_logits' + batch_size = 5 + num_classes = 20 + num_true = 5 + num_samples = 10 + seed = 10 + remove_accidental_hits = False + self.set_data(batch_size, num_classes, num_true, num_samples, seed, + remove_accidental_hits) + self.compute() + + +class TestSampleLogitsOp3(TestSampleLogitsOp): + def setUp(self): + self.op_type = 'sample_logits' + batch_size = 5 + num_classes = 100 + num_true = 5 + num_samples = 25 + seed = 10 + remove_accidental_hits = True + self.set_data(batch_size, num_classes, num_true, num_samples, seed, + remove_accidental_hits) + self.compute() + + +class TestSampleLogitsOp4(TestSampleLogitsOp): + def setUp(self): + self.op_type = 'sample_logits' + batch_size = 5 + num_classes = 100 + num_true = 5 + num_samples = 25 + seed = 10 + remove_accidental_hits = False + self.set_data(batch_size, num_classes, num_true, num_samples, seed, + remove_accidental_hits) + self.compute() + + +class TestSampleLogitsOpV2(OpTest): + ''' + Test SampleLogitsOp, but with random results precomputed + in C++ and copied to python and just test the non-random part. + ''' + + def generate_data(self, logits, label, num_samples, seed, + remove_accidental_hits, use_custom_samples): + self.attrs = { + 'num_samples': num_samples, + 'use_custom_samples': use_custom_samples, + 'remove_accidental_hits': remove_accidental_hits, + 'seed': seed + } + self.inputs = {'Logits': logits, 'Label': label} + + def set_data(self, num_classes, num_samples, seed, remove_accidental_hits): + label = np.array([[6, 12, 15, 5, 1], [0, 9, 4, 1, 10], + [0, 2, 10, 16, 13], [14, 4, 7, 2, 1], + [3, 18, 11, 8, 14]]) + batch_size, num_true = label.shape + use_custom_samples = False + + num_sampled_classes = num_samples + num_true + logits = np.random.randn(batch_size, num_classes) + + remove_accidental_hits = remove_accidental_hits + self.generate_data(logits, label, num_samples, seed, + remove_accidental_hits, use_custom_samples) + + # python and c++ use different random generator + # use fetched samples from c++ for python code + self.fetched_samples = np.array( + [[6, 12, 15, 5, 1, 5, 15, 1, 0, 8, 3, 14, 2, 13, 4], + [0, 9, 4, 1, 10, 5, 15, 1, 0, 8, 3, 14, 2, 13, 4], + [0, 2, 10, 16, 13, 5, 15, 1, 0, 8, 3, 14, 2, 13, 4], + [14, 4, 7, 2, 1, 5, 15, 1, 0, 8, 3, 14, 2, 13, 4], + [3, 18, 11, 8, 14, 5, 15, 1, 0, 8, 3, 14, 2, 13, 4]]) + fectched_num_tries = 21 + + probabilities = np.zeros( + (batch_size, num_sampled_classes), dtype=np.float64) + + sampler = LogUniformSampler(num_classes, seed) + for j in range(num_sampled_classes): + for i in range(batch_size): + probabilities[i, j] = sampler.probability(self.fetched_samples[ + i, j]) + probabilities[i, j] = adjust_prob( + probabilities[i, j], num_samples, fectched_num_tries) + self.probabilities = probabilities + + def compute(self): + out = sample_logits(self.inputs["Logits"], self.inputs["Label"], + self.attrs["num_samples"], self.attrs["seed"], + self.attrs["remove_accidental_hits"], True, + self.fetched_samples, self.probabilities) + self.outputs = { + 'SampledLogits': out[0], + 'Samples': out[1], + 'SampledLabel': out[2], + 'Probabilities': out[3] + } + + def setUp(self): + self.op_type = 'sample_logits' + num_samples = 10 + num_classes = 20 + seed = 10 + remove_accidental_hits = True + + self.set_data(num_classes, num_samples, seed, remove_accidental_hits) + self.compute() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + pass + self.check_grad( + ["Logits"], ["SampledLogits", "Samples"], max_relative_error=0.02) + + +class TestSampleLogitsOpV3(OpTest): + ''' + Test SampleLogitsOp, but with random results precomputed + in C++ and copied to python and just test the non-random part. + ''' + + def generate_data(self, logits, label, num_samples, seed, + remove_accidental_hits, use_custom_samples): + self.attrs = { + 'num_samples': num_samples, + 'use_custom_samples': use_custom_samples, + 'remove_accidental_hits': remove_accidental_hits, + 'seed': seed + } + self.inputs = {'Logits': logits, 'Label': label} + + def set_data(self, num_classes, num_samples, seed, remove_accidental_hits): + self.fetched_samples = np.array([[ + 52, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ], [ + 2, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ], [ + 2, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ], [ + 17, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ], [ + 96, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ], [ + 2, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ], [ + 17, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ], [ + 96, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ], [ + 37, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ], [ + 2, + 3, + 12, + 74, + 28, + 1, + 79, + 2, + 42, + 8, + 13, + 0, + 18, + 88, + 49, + 14, + 46, + 39, + 57, + 26, + 75, + 9, + 50, + 16, + 66, + 6, + 23, + 5, + 11, + 17, + 54, + 35, + 20, + 53, + 10, + 47, + 80, + 38, + 7, + 4, + 31, + 15, + 19, + 58, + 22, + 34, + 41, + 73, + 62, + 95, + 25, + 70, + 37, + 30, + 65, + 27, + 51, + 43, + 32, + 99, + 21, + 56, + 29, + 40, + 69, + 55, + 98, + 77, + 67, + 33, + 89, + 63, + 81, + 59, + 48, + 91, + 68, + 72, + 61, + 52, + 86, + ]]) + fectched_num_tries = 323 + + label = self.fetched_samples[:, 0:1] + batch_size, num_true = label.shape + use_custom_samples = False + + #import pdb; pdb.set_trace() + num_sampled_classes = num_samples + num_true + logits = np.random.randn(batch_size, num_classes) + + remove_accidental_hits = remove_accidental_hits + self.generate_data(logits, label, num_samples, seed, + remove_accidental_hits, use_custom_samples) + + # python and c++ use different random generator + # use fetched samples from c++ for python code + probabilities = np.zeros( + (batch_size, num_sampled_classes), dtype=np.float64) + + sampler = LogUniformSampler(num_classes, seed) + for j in range(num_sampled_classes): + for i in range(batch_size): + probabilities[i, j] = sampler.probability(self.fetched_samples[ + i, j]) + probabilities[i, j] = adjust_prob( + probabilities[i, j], num_samples, fectched_num_tries) + self.probabilities = probabilities + + def compute(self): + out = sample_logits(self.inputs["Logits"], self.inputs["Label"], + self.attrs["num_samples"], self.attrs["seed"], + self.attrs["remove_accidental_hits"], True, + self.fetched_samples, self.probabilities) + self.outputs = { + 'SampledLogits': out[0], + 'Samples': out[1], + 'SampledLabel': out[2], + 'Probabilities': out[3] + } + + def setUp(self): + self.op_type = 'sample_logits' + num_samples = 80 + num_classes = 100 + seed = 123 + remove_accidental_hits = True + + self.set_data(num_classes, num_samples, seed, remove_accidental_hits) + self.compute() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + pass + self.check_grad( + ["Logits"], ["SampledLogits", "Samples"], max_relative_error=0.02) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/testsuite.py b/python/paddle/fluid/tests/unittests/testsuite.py index c4eb26893..1fe62fa4a 100644 --- a/python/paddle/fluid/tests/unittests/testsuite.py +++ b/python/paddle/fluid/tests/unittests/testsuite.py @@ -156,8 +156,26 @@ def append_input_output(block, op_proto, np_list, is_input, dtype): return var_dict +def var_cast(block, input): + if input.dtype == core.VarDesc.VarType.FP32 or input.dtype == core.VarDesc.VarType.FP32: + return input + out = block.create_var(dtype="float32", shape=[1]) + op = block.append_op( + inputs={"X": input}, + outputs={"Out": out}, + type='cast', + attrs={ + 'out_dtype': core.VarDesc.VarType.FP32, + 'in_dtype': input.dtype + }) + op.desc.infer_var_type(block.desc) + op.desc.infer_shape(block.desc) + return out + + def append_loss_ops(block, output_names): mean_inputs = list(map(block.var, output_names)) + mean_inputs = [var_cast(block, x) for x in mean_inputs] if len(mean_inputs) == 1: loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1]) -- GitLab