diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index 8870b1cd7acfe4074cfb0b5484b5ab852bba8659..c64be2170c4a3a874681fb8293b591d41e853158 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -128,7 +128,6 @@ message ReluOpConf { message SoftmaxOpConf { string in = 1; string out = 2; - int32 axis = 3; } message MultinomialLogisticLossOpConf { diff --git a/oneflow/core/operator/softmax_op.cpp b/oneflow/core/operator/softmax_op.cpp index 2fcd5e4c3261211948de899c66f41f2e3051fc9f..016586853a3fb56970b5d1f9d34eb1150d04fc82 100644 --- a/oneflow/core/operator/softmax_op.cpp +++ b/oneflow/core/operator/softmax_op.cpp @@ -8,6 +8,7 @@ void SoftmaxOp::InitFromOpConf(const OperatorConf& op_conf) { EnrollInputBn("in"); EnrollOutputBn("out"); + EnrollDataTmpBn("tmp_max"); } const PbMessage& SoftmaxOp::GetSpecialConf() const { @@ -18,10 +19,9 @@ void SoftmaxOp::InferShape4FwBlobs( std::function GetShapePtr4BnInOp, ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const { std::vector vec = GetShapePtr4BnInOp(SoleIbn())->dim_vec(); - CHECK_GT(vec.size(), 1); - int32_t axis = (op_conf().softmax_conf().axis() + vec.size()) % vec.size(); - vec.erase(vec.begin() + axis); + CHECK_GE(vec.size(), 2); *GetShapePtr4BnInOp(SoleObn()) = Shape(vec); + *GetShapePtr4BnInOp(SoleDtbn()) = Shape({vec[0]}); } REGISTER_OP(OperatorConf::kSoftmaxConf, SoftmaxOp); diff --git a/oneflow/core/operator/softmax_op_test.cpp b/oneflow/core/operator/softmax_op_test.cpp index 6c9377c19ec95e725701643c8ebf5065bc2a48c0..0c52bbe21df96d59783f2ecba03245e8671b806a 100644 --- a/oneflow/core/operator/softmax_op_test.cpp +++ b/oneflow/core/operator/softmax_op_test.cpp @@ -2,17 +2,17 @@ namespace oneflow { -TEST(SoftmaxOp, softmax_3x4x5) { +TEST(SoftmaxOp, softmax_3x5) { // create softmax_op OperatorConf op_conf; op_conf.set_name("softmax_test"); - op_conf.mutable_softmax_conf()->set_axis(1); op_conf.mutable_softmax_conf()->set_in("softmax/in"); op_conf.mutable_softmax_conf()->set_out("softmax/out"); auto softmax_op = OpMgr::Singleton()->ConstructOp(op_conf); HashMap bn2shape_ptr{ - {softmax_op->SoleIbn(), new Shape({3, 4, 5})}, - {softmax_op->SoleObn(), new Shape}}; + {softmax_op->SoleIbn(), new Shape({3, 5})}, + {softmax_op->SoleObn(), new Shape}, + {softmax_op->SoleDtbn(), new Shape}}; auto fp = [&bn2shape_ptr](const std::string& bn) { return bn2shape_ptr.at(bn); }; @@ -20,7 +20,9 @@ TEST(SoftmaxOp, softmax_3x4x5) { softmax_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1); // test Shape* output_shape_ptr = fp(softmax_op->SoleObn()); + Shape* tmp_max_shape_ptr = fp(softmax_op->SoleDtbn()); ASSERT_EQ(*output_shape_ptr, Shape({3, 5})); + ASSERT_EQ(*tmp_max_shape_ptr, Shape({3})); } } // namespace oneflow