Unverified commit d0684930, authored by G gongweibao, committed by GitHub

parameter dispatcher. (#12666)

Parent efc5392d
@@ -20,6 +20,9 @@
 DEFINE_int32(io_threadpool_size, 100,
              "number of threads used for doing IO, default 100");
 
+DEFINE_int32(dist_threadpool_size, 0,
+             "number of threads used for distributed executed.");
+
 namespace paddle {
 namespace framework {
@@ -35,6 +38,10 @@ void ThreadPool::Init() {
   if (threadpool_.get() == nullptr) {
     // TODO(Yancey1989): specify the max threads number
     int num_threads = std::thread::hardware_concurrency();
+    if (FLAGS_dist_threadpool_size > 0) {
+      num_threads = FLAGS_dist_threadpool_size;
+      VLOG(1) << "set dist_threadpool_size to " << num_threads;
+    }
     PADDLE_ENFORCE_GT(num_threads, 0);
     threadpool_.reset(new ThreadPool(num_threads));
   }
......
@@ -190,12 +190,15 @@ bool VariableResponse::ProcSerializedField(
 #endif
   }
 
+  VLOG(7) << "ProcSerializedField:" << meta_.varname()
+          << ", type:" << meta_.type() << std::endl;
+
   framework::DDim dims = GetDims(meta_.dims());
   if (meta_.type() == sendrecv::LOD_TENSOR) {
     PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!");
     if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
       return false;
     }
     return true;
   }
@@ -206,7 +209,9 @@ bool VariableResponse::ProcSerializedField(
     return true;
   }
 
-  return true;
+  PADDLE_ENFORCE("not supported var types:", meta_.varname(), meta_.type());
+
+  return false;
 }
 }; // namespace distributed
......
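The VLOG(7) trace added above only shows up when verbose logging is enabled. A minimal sketch of how one would typically turn it on while debugging the RPC path (assumption: PaddlePaddle's logging goes through standard glog environment variables; level 7 simply matches the VLOG level used above):

```python
import os

# glog reads these environment variables at process start, so set them before
# importing paddle.fluid; verbosity 7 is high enough to show VLOG(7) traces
# such as the ProcSerializedField line in the hunk above.
os.environ["GLOG_v"] = "7"
os.environ["GLOG_logtostderr"] = "1"

import paddle.fluid as fluid  # imported after setting the env vars on purpose
```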
@@ -123,8 +123,11 @@ void ListenAndServOp::RunSyncLoop(
                optimize_prepared.begin(),
                std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
 
+  // Trainers will get all parameters from pserver in the
+  // startup program, so we will wait RequestGet first
+  rpc_service_->SetCond(distributed::kRequestGet);
+  rpc_service_->WaitBarrier(distributed::kRequestGet);
   rpc_service_->ResetBarrierCounter();
   while (true) {
     rpc_service_->Profiler().OneStep();
     // Get from multiple trainers, we don't care about the order in which
......
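The new comment spells out the ordering change: trainers now pull all parameters from the pserver in their startup programs, so the server serves (and waits on) the Get barrier once before entering the synchronous training loop. Below is a toy, pure-Python illustration of that handshake, not PaddlePaddle API; threading.Barrier stands in for WaitBarrier(kRequestGet) and NUM_TRAINERS is an arbitrary example value:

```python
import threading

NUM_TRAINERS = 2
# One slot per trainer plus one for the server, mimicking WaitBarrier(kRequestGet).
get_barrier = threading.Barrier(NUM_TRAINERS + 1)

def trainer(i):
    # Startup program: receive the initial parameters, then hit the Get barrier.
    print("trainer %d fetched initial parameters" % i)
    get_barrier.wait()

def server():
    get_barrier.wait()  # corresponds to rpc_service_->WaitBarrier(kRequestGet)
    print("all trainers are initialized, entering the sync loop")

workers = [threading.Thread(target=trainer, args=(i,)) for i in range(NUM_TRAINERS)]
workers.append(threading.Thread(target=server))
for w in workers:
    w.start()
for w in workers:
    w.join()
```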
@@ -122,7 +122,7 @@ def __bootstrap__():
         'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir',
         'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
         'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
-        'cpu_deterministic'
+        "dist_threadpool_size", 'cpu_deterministic'
     ]
     if core.is_compiled_with_dist():
         read_env_flags.append('rpc_deadline')
......
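Adding dist_threadpool_size to read_env_flags exposes the new C++ flag to users: __bootstrap__ hands the listed names to gflags via --tryfromenv, so the value can be supplied through an FLAGS_* environment variable. A minimal sketch (the value 16 is only an example):

```python
import os

# Must be set before paddle.fluid is imported; __bootstrap__ reads the FLAGS_*
# environment variables exactly once at import time.
os.environ["FLAGS_dist_threadpool_size"] = "16"

import paddle.fluid as fluid
# ThreadPool::Init() now uses 16 threads instead of std::thread::hardware_concurrency().
```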
@@ -15,7 +15,6 @@
 from . import framework
 import numpy as np
 import contextlib
-from .framework import convert_np_dtype_to_dtype_
 from .core import VarDesc
 
 __all__ = [
......
@@ -59,8 +59,8 @@ py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=$
 if(WITH_DISTRIBUTE)
     py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
     set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
-    set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180)
-    set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 180)
+    set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200)
+    set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
 endif()
 py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
 py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
......
@@ -26,6 +26,12 @@ from paddle.fluid.layers.io import ListenAndServ
 from paddle.fluid.layers.io import Recv
 from paddle.fluid.layers.io import Send
+from paddle.fluid import core
+
+RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
+)
+RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
+
 
 class TestSendOp(unittest.TestCase):
     def test_send(self):
@@ -89,18 +95,29 @@ class TestSendOp(unittest.TestCase):
     def init_client(self, place, port):
         main = fluid.Program()
         with fluid.program_guard(main):
+            main.global_block().append_op(
+                type="fetch_barrier",
+                inputs={},
+                outputs={},
+                attrs={
+                    "endpoints": ["127.0.0.1:{0}".format(port)],
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
+
             x = layers.data(
                 shape=[32, 32],
                 dtype='float32',
                 name='X',
                 append_batch_size=False)
             fluid.initializer.Constant(value=2.3)(x, main.global_block())
 
             get_var = main.global_block().create_var(
                 name="scale_0.tmp_0",  # server side var
                 dtype="float32",
                 persistable=False,
                 shape=[32, 32])
             fluid.initializer.Constant(value=2.3)(get_var, main.global_block())
 
             Send("127.0.0.1:%d" % port, [x])
             o = Recv("127.0.0.1:%d" % port, [get_var])
......
@@ -18,6 +18,7 @@ import unittest
 import paddle.fluid as fluid
 from paddle.fluid.transpiler.distribute_transpiler import delete_ops
 import traceback
+import collections
 
 
 class TranspilerTest(unittest.TestCase):
@@ -53,9 +54,18 @@ class TranspilerTest(unittest.TestCase):
         self.origin_prog = main.clone()
         return main
 
-    def get_trainer(self, config=None, sync_mode=True):
-        t = self._transpiler_instance(config, sync_mode)
-        return t.get_trainer_program()
+    def get_trainer(self, config=None):
+        src = fluid.default_startup_program().clone()
+
+        t = self._transpiler_instance(config)
+
+        trainer_main = t.get_trainer_program()
+        trainer_startup = fluid.default_startup_program()
+
+        assert (src.num_blocks == 1)
+        assert (trainer_startup.num_blocks == src.num_blocks)
+
+        return trainer_main, trainer_startup
 
     def get_pserver(self, ep, config=None, sync_mode=True):
         t = self._transpiler_instance(config, sync_mode)
@@ -91,7 +101,21 @@ class TestBasicModel(TranspilerTest):
         pserver, startup = self.get_pserver(self.pserver1_ep)
         pserver2, startup2 = self.get_pserver(self.pserver2_ep)
 
-        trainer = self.get_trainer()
+        trainer, trainer_startup = self.get_trainer()
+
+        # splited var blocks should be in startup program
+        self.assertTrue("fc_w.block0" in trainer_startup.global_block().vars)
+        self.assertTrue("fc_w.block1" in trainer_startup.global_block().vars)
+        self.assertTrue("fc_w" in trainer_startup.global_block().vars)
+        self.assertTrue("fc_b" in trainer_startup.global_block().vars)
+        self.assertTrue("fc_w@GRAD" not in trainer_startup.global_block().vars)
+        self.assertTrue("fc_b@GRAD" not in trainer_startup.global_block().vars)
+
+        src = [op.type for op in trainer_startup.global_block().ops]
+        dst = ['fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv', \
+            'fetch_barrier', 'concat']
+
+        self.assertEqual(src, dst)
 
         self.assertEqual([op.type for op in trainer.global_block().ops], [
             'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
@@ -142,7 +166,7 @@ class TestBasicModelWithLargeBlockSize(TranspilerTest):
         pserver, startup = self.get_pserver(self.pserver1_ep, config)
         pserver2, startup2 = self.get_pserver(self.pserver2_ep, config)
 
-        trainer = self.get_trainer(config)
+        trainer, _ = self.get_trainer(config)
 
         self.assertEqual([op.type for op in trainer.global_block().ops], [
             'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
@@ -226,7 +250,7 @@ class TestLRDecay(TranspilerTest):
     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
-        trainer = self.get_trainer()
+        trainer, _ = self.get_trainer()
 
         self.assertEqual(len(pserver.blocks), 4)
         lr_decay_ops = [op.type for op in pserver.blocks[1].ops]
@@ -256,7 +280,7 @@ class TestLRDecayConditional(TranspilerTest):
     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
-        trainer = self.get_trainer()
+        trainer, _ = self.get_trainer()
 
         serv_op = pserver.blocks[0].ops[0]
         sub_blocks = []
@@ -305,7 +329,7 @@ class TestL2Decay(TranspilerTest):
     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
-        trainer = self.get_trainer()
+        trainer, _ = self.get_trainer()
 
         self.assertEqual(len(pserver.blocks), 3)
         self.assertEqual([op.type for op in pserver.blocks[1].ops],
@@ -340,7 +364,7 @@ class TestL2DecayWithPiecewise(TranspilerTest):
     def transpiler_test_impl(self):
         pserver, startup = self.get_pserver(self.pserver1_ep)
-        trainer = self.get_trainer()
+        trainer, _ = self.get_trainer()
 
         self.assertEqual(len(pserver.blocks), 9)
         self.assertEqual([op.type for op in pserver.blocks[1].ops], [
@@ -415,7 +439,7 @@ class TestLocalLookupTable(TestDistLookupTableBase):
         self.assertEqual([op.type for op in pserver1.blocks[2].ops],
                          ["sum", "adam", "scale", "scale"])
 
-        trainer = self.get_trainer()
+        trainer, _ = self.get_trainer()
         self.assertEqual(len(trainer.blocks), 1)
         ops = [
             'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
@@ -453,7 +477,7 @@ class TestDistLookupTable(TestDistLookupTableBase):
         # 5 save table
         self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
 
-        trainer = self.get_trainer()
+        trainer, _ = self.get_trainer()
         self.assertEqual(len(trainer.blocks), 1)
         ops = [
            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
@@ -486,7 +510,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
         self.assertEqual([op.type for op in pserver1.blocks[2].ops],
                          ["adam", "scale", "scale"])
 
-        trainer = self.get_trainer(config)
+        trainer, _ = self.get_trainer(config)
         self.assertEqual(len(trainer.blocks), 1)
         ops = [
            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
@@ -525,7 +549,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
         # 5 save table
         self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
 
-        trainer = self.get_trainer(config)
+        trainer, _ = self.get_trainer(config)
         self.assertEqual(len(trainer.blocks), 1)
         ops = [
            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
......
@@ -195,6 +195,9 @@ class DistributeTranspiler(object):
         if program is None:
             program = default_main_program()
         self.origin_program = program
+        self.origin_startup_program = default_startup_program().clone()
+
+        self.startup_program = default_startup_program()
         self.trainer_num = trainers
         self.sync_mode = sync_mode
         self.trainer_id = trainer_id
@@ -205,10 +208,10 @@ class DistributeTranspiler(object):
         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()
 
-        # split and create vars, then put splited vars in dicts for later use.
+        # step 1: split and create vars, then put splited vars in dicts for later use.
         self._init_splited_vars()
 
-        # step 3.1: insert send op to send gradient vars to parameter servers
+        # step 2: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
         send_vars = []
@@ -265,7 +268,7 @@ class DistributeTranspiler(object):
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
-        # step 3.2: insert recv op to receive parameters from parameter server
+        # step 3: insert recv op to receive parameters from parameter server
         recv_vars = []
         for _, var in enumerate(send_vars):
             recv_vars.append(self.grad_param_mapping[var])
@@ -312,6 +315,8 @@ class DistributeTranspiler(object):
                 outputs={"Out": [orig_param]},
                 attrs={"axis": 0})
 
+        self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
+
         if self.has_distributed_lookup_table:
             self._replace_lookup_table_op_with_prefetch(program,
                                                         pserver_endpoints)
@@ -328,8 +333,78 @@ class DistributeTranspiler(object):
         # FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay?
         delete_ops(self.origin_program.global_block(), self.optimize_ops)
         self.origin_program.__str__()
         return self.origin_program
 
+    def _get_trainer_startup_program(self,
+                                     recv_vars,
+                                     eplist,
+                                     startup_program=None):
+        """
+        Get transpiled trainer side startup program.
+
+        Args:
+            startup_program(Program): Startup program.
+
+        Returns:
+            Program: trainer side startup program.
+        """
+        if startup_program is None:
+            startup_program = self.startup_program
+
+        # FIXME(gongwb): delete not need ops.
+        # note that: some parameter is not trainable and those ops can't be deleted.
+
+        for varname, splited_var in self.param_var_mapping.iteritems():
+            # Get the eplist of recv vars
+            eps = []
+            for var in splited_var:
+                index = [v.name for v in recv_vars].index(var.name)
+                eps.append(eplist[index])
+
+            for var in splited_var:
+                if startup_program.global_block().has_var(var.name):
+                    continue
+
+                startup_program.global_block().create_var(
+                    name=var.name,
+                    persistable=False,
+                    type=var.type,
+                    dtype=var.dtype,
+                    shape=var.shape,
+                    lod_level=var.lod_level)
+
+            op = startup_program.global_block().append_op(
+                type="recv",
+                inputs={},
+                outputs={"Out": splited_var},
+                attrs={
+                    "epmap": eps,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
+
+        startup_program.global_block().append_op(
+            type="fetch_barrier",
+            inputs={},
+            outputs={},
+            attrs={
+                "endpoints": self.pserver_endpoints,
+                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+            })
+
+        for varname, splited_var in self.param_var_mapping.iteritems():
+            #add concat ops to merge splited parameters received from parameter servers.
+            if len(splited_var) <= 1:
+                continue
+            orig_param = startup_program.global_block().vars[varname]
+            startup_program.global_block().append_op(
+                type="concat",
+                inputs={"X": splited_var},
+                outputs={"Out": [orig_param]},
+                attrs={"axis": 0})
+
+        return startup_program
+
     def get_pserver_program(self, endpoint):
         """
         Get parameter server side program.
@@ -576,14 +651,16 @@ class DistributeTranspiler(object):
             new_outputs = dict()
             # do not append startup op if var is not on this pserver
             op_on_pserver = False
-            for key in op.output_names:
-                newname, _ = _get_splited_name_and_shape(op.output(key)[0])
-                if newname:
-                    op_on_pserver = True
-                    new_outputs[key] = created_var_map[newname]
-                elif op.output(key)[0] in pserver_vars:
-                    op_on_pserver = True
-                    new_outputs[key] = pserver_vars[op.output(key)[0]]
+            # TODO(gongwb): remove this line.
+            if op.type not in ["recv", "fetch_barrier", "concat"]:
+                for key in op.output_names:
+                    newname, _ = _get_splited_name_and_shape(op.output(key)[0])
+                    if newname:
+                        op_on_pserver = True
+                        new_outputs[key] = created_var_map[newname]
+                    elif op.output(key)[0] in pserver_vars:
+                        op_on_pserver = True
+                        new_outputs[key] = pserver_vars[op.output(key)[0]]
 
             if op_on_pserver:
                 # most startup program ops have no inputs
@@ -1022,7 +1099,6 @@ class DistributeTranspiler(object):
                 var_mapping[varname] = \
                     [program.global_block().var(orig_var.name)]
                 continue
-
             var_mapping[varname] = []
             orig_shape = orig_var.shape
             orig_dim1_flatten = 1
......
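Taken together, the transpiler changes mean a trainer's startup program is now rewritten in place: instead of initializing parameters locally, it receives the splited parameter blocks from the parameter servers (recv), waits on a fetch_barrier, and concats the blocks back into the original parameters, exactly as the test's expected op list above asserts. A hedged usage sketch of the trainer side (the endpoints, trainer_id, and the omitted model-building code are placeholders; _get_trainer_startup_program itself is called internally by transpile(), as the diff shows):

```python
import paddle.fluid as fluid

# ... build the model and optimizer on fluid.default_main_program() first ...

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    pservers="127.0.0.1:6174,127.0.0.1:6175",  # placeholder pserver endpoints
    trainers=2)

trainer_main = t.get_trainer_program()

# After this commit, fluid.default_startup_program() has been transpiled as well:
# it now holds recv ops (one per parameter), a fetch_barrier against the pserver
# endpoints, and concat ops that merge splited blocks such as fc_w.block0/block1.
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
# exe.run(trainer_main, feed=..., fetch_list=...) would then run training steps.
```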