提交 8875f787 编写于 作者: C chengtbf 提交者: Will Zhang

reshape op (#547)

* reshape op

* fix bug

* add check


Former-commit-id: 2c2b79b5
上级 876255c0
......@@ -245,6 +245,12 @@ message BasicRnnOpConf {
message BasicLstmOpConf {
}
message ReshapeOpConf {
required string in = 1;
required string out = 2;
required ShapeProto shape = 3;
}
message OperatorConf {
required string name = 1;
optional string model_load_dir = 2;
......@@ -269,6 +275,7 @@ message OperatorConf {
LossPrintOpConf loss_print_conf = 119;
ReduceSumOpConf reduce_sum_conf = 120;
BasicRnnOpConf basic_rnn_conf = 121;
ReshapeOpConf reshape_conf = 122;
AveragePooling2DOpConf average_pooling_2d_conf = 200;
MaxPooling2DOpConf max_pooling_2d_conf = 201;
}
......
#include "oneflow/core/operator/reshape_op.h"
#include "oneflow/core/common/balanced_splitter.h"
namespace oneflow {
void ReshapeOp::InitFromOpConf() {
CHECK(op_conf().has_reshape_conf());
EnrollInputBn("in");
EnrollOutputBn("out");
}
const PbMessage& ReshapeOp::GetSpecialConf() const {
return op_conf().reshape_conf();
}
void ReshapeOp::InferBlobDescs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
*out_blob_desc = *in_blob_desc;
const ReshapeOpConf& conf = op_conf().reshape_conf();
std::vector<int64_t> dim_vec(1 + conf.shape().dim_size());
dim_vec[0] = in_blob_desc->shape().At(0);
FOR_RANGE(size_t, i, 1, dim_vec.size()) {
dim_vec[i] = conf.shape().dim(i - 1);
}
out_blob_desc->mut_shape() = Shape(dim_vec);
CHECK_EQ(out_blob_desc->shape().elem_cnt(), in_blob_desc->shape().elem_cnt());
}
REGISTER_OP(OperatorConf::kReshapeConf, ReshapeOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_RESHAPE_OP_H_
#define ONEFLOW_CORE_OPERATOR_RESHAPE_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class ReshapeOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(ReshapeOp);
ReshapeOp() = default;
~ReshapeOp() = default;
void InitFromOpConf() override;
const PbMessage& GetSpecialConf() const override;
bool IsElemWiseOp() const override { return true; }
bool NeedExtraInDiffMemWhenBackward() const override { return false; }
bool NeedOutWhenBackward() const override { return false; }
void InferBlobDescs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_RESHAPE_OP_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册