提交 ab58fb90 编写于 作者: W Will Zhang 提交者: GitHub

Merge pull request #182 from Oneflow-Inc/dev_chengcheng

fix wrong design of softmax_op
......@@ -128,7 +128,6 @@ message ReluOpConf {
message SoftmaxOpConf {
string in = 1;
string out = 2;
int32 axis = 3;
}
message MultinomialLogisticLossOpConf {
......
......@@ -8,6 +8,7 @@ void SoftmaxOp::InitFromOpConf(const OperatorConf& op_conf) {
EnrollInputBn("in");
EnrollOutputBn("out");
EnrollDataTmpBn("tmp_max");
}
const PbMessage& SoftmaxOp::GetSpecialConf() const {
......@@ -18,10 +19,9 @@ void SoftmaxOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
std::vector<int64_t> vec = GetShapePtr4BnInOp(SoleIbn())->dim_vec();
CHECK_GT(vec.size(), 1);
int32_t axis = (op_conf().softmax_conf().axis() + vec.size()) % vec.size();
vec.erase(vec.begin() + axis);
CHECK_GE(vec.size(), 2);
*GetShapePtr4BnInOp(SoleObn()) = Shape(vec);
*GetShapePtr4BnInOp(SoleDtbn()) = Shape({vec[0]});
}
REGISTER_OP(OperatorConf::kSoftmaxConf, SoftmaxOp);
......
......@@ -2,17 +2,17 @@
namespace oneflow {
TEST(SoftmaxOp, softmax_3x4x5) {
TEST(SoftmaxOp, softmax_3x5) {
// create softmax_op
OperatorConf op_conf;
op_conf.set_name("softmax_test");
op_conf.mutable_softmax_conf()->set_axis(1);
op_conf.mutable_softmax_conf()->set_in("softmax/in");
op_conf.mutable_softmax_conf()->set_out("softmax/out");
auto softmax_op = OpMgr::Singleton()->ConstructOp(op_conf);
HashMap<std::string, Shape*> bn2shape_ptr{
{softmax_op->SoleIbn(), new Shape({3, 4, 5})},
{softmax_op->SoleObn(), new Shape}};
{softmax_op->SoleIbn(), new Shape({3, 5})},
{softmax_op->SoleObn(), new Shape},
{softmax_op->SoleDtbn(), new Shape}};
auto fp = [&bn2shape_ptr](const std::string& bn) {
return bn2shape_ptr.at(bn);
};
......@@ -20,7 +20,9 @@ TEST(SoftmaxOp, softmax_3x4x5) {
softmax_op->InferShape4FwBlobs(fp, kDataParallel, 0, 1);
// test
Shape* output_shape_ptr = fp(softmax_op->SoleObn());
Shape* tmp_max_shape_ptr = fp(softmax_op->SoleDtbn());
ASSERT_EQ(*output_shape_ptr, Shape({3, 5}));
ASSERT_EQ(*tmp_max_shape_ptr, Shape({3}));
}
} // namespace oneflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册