未验证 提交 7a4151e5 编写于 作者: X Xiaoyu Xu 提交者: GitHub

Merge branch 'master' into fea/graph_op_debug

......@@ -18,13 +18,13 @@ function write_to_file_and_print {
# Single-process ResNet50 speed comparison against PyTorch at several input sizes.
# Each run pipes its output through `check_relative_speed <threshold>` and then
# `write_to_file_and_print`.
# NOTE(review): the 4x/2x/1x rows appear twice below with different thresholds
# (1.05/1.09/0.95 and 1.01/1.06/0.94) - presumably the before/after sides of a diff;
# confirm which set is intended.
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.01 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.05 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.05 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.09 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.95 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.01 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.06 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.94 | write_to_file_and_print
# Same comparison launched on 2 processes per node via oneflow.distributed.launch with --ddp.
# NOTE(review): the 4x3x224x224 row also appears twice (thresholds 0.93 and 0.91) - confirm
# which one is intended.
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 0.99 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 0.99 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.93 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.91 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.83 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.82 | write_to_file_and_print
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
// Autograd capture state for the "distributed_partial_fc_sample" op: records whether a
// gradient is needed and where the tensors saved for the backward pass are stored.
struct PartialFCSampleState : public AutoGradCaptureState {
  // True when the first forward input requires grad (filled in by Capture).
  bool requires_grad{false};
  // Index returned by SaveTensorForBackward for the sampled_label output; -1 until saved.
  int32_t index_sampled_label{-1};
  // Index returned by SaveTensorForBackward for the first forward input; -1 until saved.
  int32_t index_weight{-1};
};
// Gradient function for the "distributed_partial_fc_sample" op. PartialFCSampleState
// carries the data recorded during the forward pass over to Apply.
class PartialFCSample : public OpExprGradFunction<PartialFCSampleState> {
public:
// Checks that `op` is a UserOpExpr and stores the attributes built from its proto.
Maybe<void> Init(const OpExpr& op) override;
// Records requires_grad of the first input and saves the tensors Apply reads.
Maybe<void> Capture(PartialFCSampleState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
// Writes the gradient of the first forward input into in_grads->at(0).
Maybe<void> Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
// Attribute map created from the forward op's proto in Init.
AttrMap base_attrs_;
};
// Prepares the gradient function for a forward op expression.
// `op` must be a UserOpExpr; otherwise an error is returned through Maybe. On success the
// attribute map derived from the op's proto is cached in base_attrs_.
Maybe<void> PartialFCSample::Init(const OpExpr& op) {
  // The variable name is kept as-is because the check macro may embed it in its message.
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);

  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}
// Records what Apply needs for the backward pass.
//   inputs:  (weight, label) of the forward op
//   outputs: (mapped_label, sampled_label, sampled_weight) of the forward op
// When the weight does not require grad nothing is saved and Apply becomes a no-op.
Maybe<void> PartialFCSample::Capture(PartialFCSampleState* ctx, const TensorTuple& inputs,
                                     const TensorTuple& outputs, const AttrMap& attrs) const {
  // Validate arity up front (same style as Apply) so a malformed call is reported through
  // Maybe instead of failing inside the .at() lookups below.
  CHECK_EQ_OR_RETURN(inputs.size(), 2);
  CHECK_EQ_OR_RETURN(outputs.size(), 3);
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ctx->index_sampled_label = ctx->SaveTensorForBackward(outputs.at(1));  // sampled_label
  ctx->index_weight = ctx->SaveTensorForBackward(inputs.at(0));          // weight
  return Maybe<void>::Ok();
}
// Computes the gradient of the forward inputs from the gradients of the three outputs.
//   out_grads: gradients of (mapped_label, sampled_label, sampled_weight); only the
//              sampled_weight gradient (index 2) is used.
//   in_grads:  resized to one slot per forward input (weight, label). Slot 0 receives the
//              weight gradient; the label slot is left empty.
// Returns early, leaving all slots empty, when Capture recorded requires_grad == false.
Maybe<void> PartialFCSample::Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads,
                                   TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 3);
  // The forward op has two inputs, so the gradient tuple needs two slots. Shrinking it to a
  // single slot would break any consumer that indexes the slot belonging to `label`.
  in_grads->resize(2);
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  const auto& diff_sampled_weight = out_grads.at(2);  // diff of sampled_weight
  const auto& sampled_tensor = ctx->SavedTensors().at(ctx->index_sampled_label);
  const auto& weight = ctx->SavedTensors().at(ctx->index_weight);
  // Step 1: pass the sampled-weight gradient and the sampled labels through the
  // "disable boxing" helper op, which returns two tensors.
  const auto& out_tensors_of_op0 = JUST(
      functional::DistributedPariticalFCSampleDisableBoxing(diff_sampled_weight, sampled_tensor));
  // Step 2: combine both results with `weight` along axis 0.
  // NOTE(review): presumably this scatters the sampled rows back into a tensor shaped like
  // `weight` - confirm against UnsortedSegmentSumLike.
  const auto& out_tensors_of_op1 = JUST(functional::UnsortedSegmentSumLike(
      out_tensors_of_op0->at(0), out_tensors_of_op0->at(1), weight, 0));
  in_grads->at(0) = out_tensors_of_op1;
  return Maybe<void>::Ok();
}
// Registers PartialFCSample as the gradient function for the op type name
// "distributed_partial_fc_sample".
REGISTER_OP_EXPR_GRAD_FUNCTION("distributed_partial_fc_sample", PartialFCSample);
} // namespace one
} // namespace oneflow
......@@ -1511,6 +1511,17 @@
signature: "TensorTuple (Tensor log_probs, Tensor input_lengths, Bool merge_repeated=True) => CtcGreedyDecoder"
bind_python: True
# Functional entry bound to the "DistributedPariticalFCSample" functor: takes
# (weight, label, num_sample) and returns a TensorTuple. Exposed to Python.
- name: "distributed_partial_fc_sample"
signature:
"TensorTuple (Tensor weight, Tensor label, Int64 num_sample) => DistributedPariticalFCSample"
bind_python: True
# Functional entry bound to the "DistributedPariticalFCSampleDisableBoxing" functor: takes
# (sampled_weight_diff, sampled_label) and returns a TensorTuple. Not exposed to Python
# (bind_python: False).
- name: "distributed_partial_fc_sample_disable_boxing"
signature:
"TensorTuple (Tensor sampled_weight_diff, Tensor sampled_label) => DistributedPariticalFCSampleDisableBoxing"
bind_python: False
# Functional entry bound to the "Meshgrid" functor: takes a TensorTuple and returns a
# TensorTuple. Exposed to Python.
- name: "meshgrid"
signature: "TensorTuple (TensorTuple tensors) => Meshgrid"
bind_python: True
......@@ -1872,6 +1872,48 @@ class CtcGreedyDecoderFunctor {
std::shared_ptr<OpExpr> op_;
};
// Functor that dispatches the "distributed_partial_fc_sample" user op.
class PartialFCSampleFunctor {
 public:
  // Builds the op expression once, with inputs (weight, label) and outputs
  // (mapped_label, sampled_label, sampled_weight).
  PartialFCSampleFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("distributed_partial_fc_sample")
                         .Input("weight")
                         .Input("label")
                         .Output("mapped_label")
                         .Output("sampled_label")
                         .Output("sampled_weight")
                         .Build());
  }

  // Runs the op on (weight, label) with the "num_sample" attribute set to `num_sample`
  // and returns the op outputs as a TensorTuple.
  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& weight,
                                const std::shared_ptr<one::Tensor>& label,
                                const int64_t& num_sample) const {
    MutableAttrMap attrs;
    JUST(attrs.SetAttr<int64_t>("num_sample", num_sample));
    const TensorTuple op_inputs{weight, label};
    return OpInterpUtil::Dispatch<TensorTuple>(*op_, op_inputs, attrs);
  }

 private:
  // Op expression created in the constructor and reused by every call.
  std::shared_ptr<OpExpr> op_;
};
// Functor that dispatches the "distributed_partial_fc_sample_disable_boxing" user op.
class PariticalFCSampleDisableBoxing {
 public:
  // Builds the op expression once, with inputs (sampled_weight_diff, sampled_label) and
  // outputs (boxing_disabled_sampled_weight_diff, boxing_disabled_sampled_label).
  PariticalFCSampleDisableBoxing() {
    op_ = CHECK_JUST(one::OpBuilder("distributed_partial_fc_sample_disable_boxing")
                         .Input("sampled_weight_diff")
                         .Input("sampled_label")
                         .Output("boxing_disabled_sampled_weight_diff")
                         .Output("boxing_disabled_sampled_label")
                         .Build());
  }

  // Runs the op on the two tensors (no attributes) and returns both outputs as a TensorTuple.
  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& sampled_weight_diff,
                                const std::shared_ptr<one::Tensor>& sampled_label) const {
    const TensorTuple op_inputs{sampled_weight_diff, sampled_label};
    return OpInterpUtil::Dispatch<TensorTuple>(*op_, op_inputs);
  }

 private:
  // Op expression created in the constructor and reused by every call.
  std::shared_ptr<OpExpr> op_;
};
} // namespace impl
ONEFLOW_FUNCTION_LIBRARY(m) {
......@@ -1932,6 +1974,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::FusedBiasAddDropoutFunctor>("FusedBiasAddDropout");
m.add_functor<impl::FusedScaleTrilFunctor>("FusedScaleTril");
m.add_functor<impl::CtcGreedyDecoderFunctor>("CtcGreedyDecoder");
m.add_functor<impl::PartialFCSampleFunctor>("DistributedPariticalFCSample");
m.add_functor<impl::PariticalFCSampleDisableBoxing>("DistributedPariticalFCSampleDisableBoxing");
};
} // namespace functional
......
......@@ -152,7 +152,10 @@ class DistributedPartialFcSampleOpKernelState final : public user_op::OpKernelSt
SetupKernel<<<BlocksNum4ThreadsNum(num_classes), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(seed, curand_states_);
}
~DistributedPartialFcSampleOpKernelState() { OF_CUDA_CHECK(cudaFree(curand_states_)); };
// Frees curand_states_. A cudaErrorCudartUnloading result from cudaFree is tolerated and
// every other status goes through OF_CUDA_CHECK.
// NOTE(review): presumably this guards against the destructor running during process
// shutdown, after the CUDA runtime has started unloading - confirm.
~DistributedPartialFcSampleOpKernelState() {
  const cudaError_t ret = cudaFree(curand_states_);
  if (ret == cudaErrorCudartUnloading) { return; }
  OF_CUDA_CHECK(ret);
}
// Read-only accessors for the stored lower_ and upper_ values.
// NOTE(review): presumably the bounds of the class-id range handled by this kernel state -
// confirm against the constructor.
int64_t lower() const { return lower_; }
int64_t upper() const { return upper_; }
......
......@@ -128,6 +128,7 @@ from oneflow._C import softplus
from oneflow._C import tril
from oneflow._C import triu
from oneflow._C import pad
from oneflow._C import distributed_partial_fc_sample
from oneflow._C import transpose
from oneflow._C import relu
from oneflow._C import softmax
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from oneflow.test_utils.automated_test_util import *
import oneflow as flow
import oneflow.unittest
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
class TestParitalFC(flow.unittest.TestCase):
    """Forward shape check for flow.distributed_partial_fc_sample on CUDA placement."""

    def test_parital_fc(test_case):
        # Problem size: 50000 classes with 128-dim weights, 512 labels, 5000 samples.
        num_classes, embedding_dim, batch_size = 50000, 128, 512
        num_sample = 5000

        placement = flow.env.all_device_placement("cuda")
        broadcast = flow.sbp.broadcast
        w = flow.randn(num_classes, embedding_dim, placement=placement, sbp=broadcast)
        label = flow.randint(
            0, num_classes, (batch_size,), placement=placement, sbp=broadcast
        )

        out = flow.distributed_partial_fc_sample(w, label, num_sample)

        # The three outputs are checked in order against their expected shapes.
        expected_shapes = (
            flow.Size([batch_size]),
            flow.Size([num_sample]),
            flow.Size([num_sample, embedding_dim]),
        )
        for position, expected in enumerate(expected_shapes):
            test_case.assertTrue(out[position].shape == expected)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册