Unverified commit c857841e — authored by WangZhen, committed via GitHub

Adapt tensor num_samples for multinomial (#45522)

Parent: 51f4291c
...@@ -31,7 +31,8 @@ class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -31,7 +31,8 @@ class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "A tensor contains probabilities of categories"); AddInput("X", "A tensor contains probabilities of categories");
AddOutput("Out", "The output tensor of multinomial op"); AddOutput("Out", "The output tensor of multinomial op");
AddAttr<int>("num_samples", "number of the generated samples") AddAttr<int>("num_samples", "number of the generated samples")
.SetDefault(1); .SetDefault(1)
.SupportTensor();
AddAttr<bool>("replacement", "can a category be sampled more than once") AddAttr<bool>("replacement", "can a category be sampled more than once")
.SetDefault(false); .SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
...@@ -46,6 +47,13 @@ This OP returns a Tensor filled with the sampled categoris according to Multinom ...@@ -46,6 +47,13 @@ This OP returns a Tensor filled with the sampled categoris according to Multinom
class MultinomialOp : public framework::OperatorWithKernel { class MultinomialOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -1828,7 +1828,7 @@ ...@@ -1828,7 +1828,7 @@
optional : rois_num optional : rois_num
# num_samples is a Scalar so it can be passed either as a python int or as
# a 0-D/1-element Tensor (see #45522).
- api : multinomial
  args : (Tensor x, Scalar num_samples, bool replacement)
  output : Tensor(out)
  infer_meta :
    func : MultinomialInferMeta
......
...@@ -1905,9 +1905,11 @@ void ModeInferMeta(const MetaTensor& x, ...@@ -1905,9 +1905,11 @@ void ModeInferMeta(const MetaTensor& x,
} }
void MultinomialInferMeta(const MetaTensor& x, void MultinomialInferMeta(const MetaTensor& x,
int num_samples, const Scalar& num_samples,
bool replacement, bool replacement,
MetaTensor* out) { MetaTensor* out,
MetaConfig config) {
auto int_num_samples = num_samples.to<int>();
auto x_dim = x.dims(); auto x_dim = x.dims();
int64_t x_rank = x_dim.size(); int64_t x_rank = x_dim.size();
PADDLE_ENFORCE_GT(x_rank, PADDLE_ENFORCE_GT(x_rank,
...@@ -1928,12 +1930,16 @@ void MultinomialInferMeta(const MetaTensor& x, ...@@ -1928,12 +1930,16 @@ void MultinomialInferMeta(const MetaTensor& x,
out_dims[i] = x_dim[i]; out_dims[i] = x_dim[i];
} }
PADDLE_ENFORCE_GT( if (config.is_runtime || !num_samples.FromTensor()) {
num_samples, PADDLE_ENFORCE_GT(int_num_samples,
0, 0,
errors::InvalidArgument( errors::InvalidArgument(
"The number of samples should be > 0, but got %d.", num_samples)); "The number of samples should be > 0, but got %d.",
out_dims[x_rank - 1] = num_samples; int_num_samples));
out_dims[x_rank - 1] = int_num_samples;
} else {
out_dims[x_rank - 1] = -1;
}
out->set_dims(make_ddim(out_dims)); out->set_dims(make_ddim(out_dims));
out->set_dtype(DataType::INT64); out->set_dtype(DataType::INT64);
......
...@@ -277,9 +277,10 @@ void ModeInferMeta(const MetaTensor& x, ...@@ -277,9 +277,10 @@ void ModeInferMeta(const MetaTensor& x,
MetaTensor* indices); MetaTensor* indices);
void MultinomialInferMeta(const MetaTensor& x, void MultinomialInferMeta(const MetaTensor& x,
int num_samples, const Scalar& num_samples,
bool replacement, bool replacement,
MetaTensor* out); MetaTensor* out,
MetaConfig config = MetaConfig());
void NanmedianInferMeta(const MetaTensor& x, void NanmedianInferMeta(const MetaTensor& x,
const IntArray& axes, const IntArray& axes,
......
...@@ -23,7 +23,7 @@ namespace phi { ...@@ -23,7 +23,7 @@ namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void MultinomialKernel(const Context& dev_ctx, void MultinomialKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int num_samples, const Scalar& num_samples,
bool replacement, bool replacement,
DenseTensor* out) { DenseTensor* out) {
auto* in_data = x.data<T>(); auto* in_data = x.data<T>();
...@@ -36,7 +36,7 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -36,7 +36,7 @@ void MultinomialKernel(const Context& dev_ctx,
funcs::MultinomialFunctor<T>(dev_ctx, funcs::MultinomialFunctor<T>(dev_ctx,
out_data, out_data,
in_data, in_data,
num_samples, num_samples.to<int>(),
replacement, replacement,
num_categories, num_categories,
num_distributions); num_distributions);
......
...@@ -128,9 +128,10 @@ __global__ void sampleMultinomialWithReplacement( ...@@ -128,9 +128,10 @@ __global__ void sampleMultinomialWithReplacement(
template <typename T, typename Context> template <typename T, typename Context>
void MultinomialKernel(const Context& dev_ctx, void MultinomialKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int num_samples, const Scalar& num_samples,
bool replacement, bool replacement,
DenseTensor* out) { DenseTensor* out) {
auto int_num_samples = num_samples.to<int>();
auto* in_data = x.data<T>(); auto* in_data = x.data<T>();
int64_t* out_data = dev_ctx.template Alloc<int64_t>(out); int64_t* out_data = dev_ctx.template Alloc<int64_t>(out);
auto in_dims = x.dims(); auto in_dims = x.dims();
...@@ -172,7 +173,7 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -172,7 +173,7 @@ void MultinomialKernel(const Context& dev_ctx,
} }
int valid_samples = num_categories - zero_num; int valid_samples = num_categories - zero_num;
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
num_samples, int_num_samples,
valid_samples, valid_samples,
errors::InvalidArgument("When replacement=False, 'num_samples' " errors::InvalidArgument("When replacement=False, 'num_samples' "
"must less than or eaqual to the number of " "must less than or eaqual to the number of "
...@@ -191,14 +192,14 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -191,14 +192,14 @@ void MultinomialKernel(const Context& dev_ctx,
rand_data[idx] = in_data[idx] / rand_data[idx]; rand_data[idx] = in_data[idx] / rand_data[idx];
}); });
if (num_samples == 1) { if (int_num_samples == 1) {
ArgMaxKernel<T, Context>( ArgMaxKernel<T, Context>(
dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out); dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out);
} else { } else {
std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims()); std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims());
DenseTensor value = Empty<T, Context>(dev_ctx, IntArray(out_dim_vec)); DenseTensor value = Empty<T, Context>(dev_ctx, IntArray(out_dim_vec));
TopkKernel<T, Context>( TopkKernel<T, Context>(
dev_ctx, rand, Scalar(num_samples), -1, true, true, &value, out); dev_ctx, rand, num_samples, -1, true, true, &value, out);
} }
return; return;
} }
...@@ -268,7 +269,7 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -268,7 +269,7 @@ void MultinomialKernel(const Context& dev_ctx,
int64_t device_id = dev_ctx.GetPlace().GetDeviceId(); int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id); const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
int grid_y = std::min<int64_t>(num_distributions, prop.maxGridSize[1]); int grid_y = std::min<int64_t>(num_distributions, prop.maxGridSize[1]);
dim3 grid((num_samples - 1) / block.x + 1, grid_y); dim3 grid((int_num_samples - 1) / block.x + 1, grid_y);
auto gen_cuda = dev_ctx.GetGenerator(); auto gen_cuda = dev_ctx.GetGenerator();
size_t curand4_loop_times = size_t curand4_loop_times =
...@@ -278,7 +279,7 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -278,7 +279,7 @@ void MultinomialKernel(const Context& dev_ctx,
auto seed_offset = gen_cuda->IncrementOffset(increment); auto seed_offset = gen_cuda->IncrementOffset(increment);
sampleMultinomialWithReplacement<T> sampleMultinomialWithReplacement<T>
<<<grid, block, 0, dev_ctx.stream()>>>(num_samples, <<<grid, block, 0, dev_ctx.stream()>>>(int_num_samples,
out_data, out_data,
num_distributions, num_distributions,
num_categories, num_categories,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
namespace phi { namespace phi {
...@@ -21,7 +22,7 @@ namespace phi { ...@@ -21,7 +22,7 @@ namespace phi {
// Samples `num_samples` category indices per distribution in `x`.
// num_samples is a Scalar so callers may supply it as a tensor.
template <typename T, typename Context>
void MultinomialKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const Scalar& num_samples,
                       bool replacement,
                       DenseTensor* out);
......
...@@ -21,6 +21,8 @@ from paddle.fluid import core ...@@ -21,6 +21,8 @@ from paddle.fluid import core
from op_test import OpTest from op_test import OpTest
import numpy as np import numpy as np
import os import os
from paddle.fluid import Program, program_guard
from test_attribute_var import UnittestBase
def sample_output_one_dimension(out, dim): def sample_output_one_dimension(out, dim):
...@@ -294,5 +296,47 @@ class TestRandomValue(unittest.TestCase): ...@@ -294,5 +296,47 @@ class TestRandomValue(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
class TestMultinomialTensorNumSamples(UnittestBase):
    """Static-graph test for paddle.multinomial where ``num_samples`` is
    supplied as a Tensor (via ``paddle.assign``) rather than a plain int,
    exercising the new ``SupportTensor()`` attribute path (#45522)."""

    def init_info(self):
        # Shape of the feature fed into the Linear layer in test_static.
        self.shapes = [[3, 4]]
        self.save_path = os.path.join(self.temp_dir.name, self.path_prefix())

    def path_prefix(self):
        return 'multinomial_tensor_num'

    def var_prefix(self):
        # A tensor-valued attribute is printed as "Var[...]" in the program
        # description, which test_static asserts on.
        return "Var["

    def call_func(self, x):
        # num_samples as a tensor triggers the SupportTensor attribute path.
        num_samples = paddle.assign(3)
        out = paddle.multinomial(x, num_samples)
        return out

    def test_static(self):
        main_prog = Program()
        startup_prog = Program()  # fixed typo: was "starup_prog"
        with program_guard(main_prog, startup_prog):
            fc = paddle.nn.Linear(4, 10)
            x = paddle.randn([3, 4])
            x.stop_gradient = False
            feat = fc(x)
            # multinomial requires non-negative probabilities.
            out = self.call_func(paddle.abs(feat))
            sgd = paddle.optimizer.SGD()
            sgd.minimize(paddle.mean(paddle.cast(out, 'float32')))
            self.assertTrue(self.var_prefix() in str(main_prog))

            exe = paddle.static.Executor()
            exe.run(startup_prog)
            res = exe.run(fetch_list=[feat, out])
            paddle.static.save_inference_model(self.save_path, [x], [feat, out],
                                               exe)
            # 3 distributions, 3 samples each -> (3, 3).
            np.testing.assert_equal(res[1].shape, (3, 3))

            # Test for Inference Predictor
            infer_outs = self.infer_prog()
            np.testing.assert_equal(infer_outs[1].shape, (3, 3))
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册