提交 da2cc99f 编写于 作者: T tangwei12

sampling op optimize

上级 4973e07b
...@@ -25,9 +25,9 @@ class SamplingIdOp : public framework::OperatorWithKernel { ...@@ -25,9 +25,9 @@ class SamplingIdOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of RowConvOp should not be null."); "Input(X) of SamplingIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of RowConvOp should not be null."); "Output(Out) of SamplingIdOp should not be null.");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
...@@ -43,8 +43,7 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -43,8 +43,7 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", AddInput("X",
"The input tensor of softmax. " "The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions]."); "2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Out", "Sliced data tensor."); AddOutput("Out", "SamplingId data tensor.");
AddComment(R"DOC( AddComment(R"DOC(
SamplingId Operator. SamplingId Operator.
@brief A layer for sampling id from multinomial distribution from the @brief A layer for sampling id from multinomial distribution from the
......
...@@ -16,19 +16,6 @@ limitations under the License. */ ...@@ -16,19 +16,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/operators/sampling_id_op.h" #include "paddle/fluid/operators/sampling_id_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class SamplingIdOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {}
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
sampling_id, sampling_id,
......
...@@ -50,7 +50,7 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -50,7 +50,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<int64_t> out_dim; std::vector<int64_t> out_dim;
out_dim.push_back(static_cast<int64_t>(batch_size)); out_dim.push_back(static_cast<int64_t>(batch_size));
Tensor* output = context.Output<Tensor>("Output"); Tensor* output = context.Output<Tensor>("Out");
output->Resize(framework::make_ddim(out_dim)); output->Resize(framework::make_ddim(out_dim));
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(ids, context.device_context(), output); framework::TensorFromVector(ids, context.device_context(), output);
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestSamplingIdOp(OpTest):
def setUp(self):
self.op_type = "sampling_id"
self.use_mkldnn = False
self.init_kernel_type()
X = np.random.random((3, 4)).astype('float32')
self.inputs = {"X": X}
Y = np.random.random(3).astype('float32')
self.outputs = {'Out': Y}
self.attrs = {'use_mkldnn': self.use_mkldnn}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def init_kernel_type(self):
pass
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册