提交 6dd3a61b 编写于 作者: L Luo Tao

combine batch_size_like.cc into batch_size_like.h

上级 c02f773a
...@@ -155,6 +155,7 @@ op_library(print_op DEPS lod_tensor) ...@@ -155,6 +155,7 @@ op_library(print_op DEPS lod_tensor)
op_library(adagrad_op DEPS selected_rows_functor) op_library(adagrad_op DEPS selected_rows_functor)
op_library(maxout_op DEPS maxouting) op_library(maxout_op DEPS maxouting)
op_library(unpool_op DEPS unpooling) op_library(unpool_op DEPS unpooling)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling) op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op DEPS lod_rank_table) op_library(lod_rank_table_op DEPS lod_rank_table)
op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
...@@ -171,20 +172,13 @@ op_library(cos_sim_op DEPS cos_sim_functor) ...@@ -171,20 +172,13 @@ op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(create_reader_op DEPS reader) op_library(create_reader_op DEPS reader)
# Regist multiple Kernel to pybind
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv) op_library(conv_op DEPS vol2col depthwise_conv)
else() else()
op_library(conv_op DEPS vol2col) op_library(conv_op DEPS vol2col)
endif() endif()
op_library(pool_op DEPS pooling)
op_library(conv_transpose_op DEPS vol2col) op_library(conv_transpose_op DEPS vol2col)
cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)
op_library(fill_constant_batch_size_like_op DEPS batch_size_like)
op_library(uniform_random_batch_size_like_op DEPS batch_size_like uniform_random_op)
op_library(gaussian_random_batch_size_like_op DEPS batch_size_like gaussian_random_op)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions # FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor) op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_size_like.h"
namespace paddle {
namespace operators {
void BatchSizeLikeOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of %s should not be null.", Type());
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of %s should not be null.",
Type());
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0);
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto output_dim = framework::make_ddim(shape_int64);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
PADDLE_ENFORCE_GE(input_dim_idx, 0);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
PADDLE_ENFORCE_GE(output_dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
}
BatchSizeLikeOpMaker::BatchSizeLikeOpMaker(OpProto *proto,
OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(Tensor) Tensor "
"whose input_dim_idx'th dimension specifies the batch_size");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<int>("input_dim_idx",
"(int, default 0) The index of input's batch size dimension")
.SetDefault(0);
AddAttr<int>("output_dim_idx",
"(int, default 0) The index of output's batch size dimension")
.SetDefault(0);
}
} // namespace operators
} // namespace paddle
...@@ -24,12 +24,50 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel { ...@@ -24,12 +24,50 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override; void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of %s should not be null.", Type());
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of %s should not be null.", Type());
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0);
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto output_dim = framework::make_ddim(shape_int64);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
PADDLE_ENFORCE_GE(input_dim_idx, 0);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
PADDLE_ENFORCE_GE(output_dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
}
}; };
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker { class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker); BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(Tensor) Tensor "
"whose input_dim_idx'th dimension specifies the batch_size");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<int>("input_dim_idx",
"(int, default 0) The index of input's batch size dimension")
.SetDefault(0);
AddAttr<int>("output_dim_idx",
"(int, default 0) The index of output's batch size dimension")
.SetDefault(0);
}
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册