Unverified · commit c3151903 authored by yuyang18

Merge branch 'squeeze_op' of https://github.com/chenwhql/Paddle into pr/11812

@@ -65,6 +65,7 @@ option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better d
 option(WITH_ANAKIN "Compile with Anakin library" OFF)
 option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
 option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
+option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
 # CMAKE_BUILD_TYPE
 if(NOT CMAKE_BUILD_TYPE)
......
@@ -83,18 +83,20 @@ else()
   set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib)
 endif()

-find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS
+if(WITH_SYSTEM_BLAS)
+  find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS
   ${REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS})
 find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS
   ${REFERENCE_CBLAS_LIB_SEARCH_PATHS})

 if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY)
   set(CBLAS_FOUND ON)
   set(CBLAS_PROVIDER REFERENCE)
   set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR})
   set(CBLAS_LIBRARIES ${REFERENCE_CBLAS_LIBRARY})
   add_definitions(-DPADDLE_USE_REFERENCE_CBLAS)
   message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
 endif()
+endif()

 if(IOS_USE_VECLIB_FOR_BLAS AND VECLIB_FOUND)
......
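Note: the new WITH_SYSTEM_BLAS option (OFF by default) now gates the reference-cblas lookup above, so a system-installed CBLAS is only picked up when the build is configured with -DWITH_SYSTEM_BLAS=ON; without it, the BLAS providers probed earlier in cblas.cmake keep winning.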
@@ -7,18 +7,18 @@ if(NOT WITH_FLUID_ONLY)
   add_subdirectory(legacy/parameter)
   if(MOBILE_INFERENCE)
-    add_subdirectory(capi)
+    add_subdirectory(legacy/capi)
   else()
     add_subdirectory(legacy/pserver)
     add_subdirectory(trainer)
     add_subdirectory(scripts)
     if(WITH_C_API)
-      add_subdirectory(capi)
+      add_subdirectory(legacy/capi)
     endif()
     if(WITH_SWIG_PY)
-      add_subdirectory(api)
+      add_subdirectory(legacy/api)
     endif()
   endif()
 endif()
......
@@ -46,9 +46,16 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
 Executor::Executor(const platform::Place& place) : place_(place) {}

 #ifdef PADDLE_WITH_DISTRIBUTE
-void Executor::Complete() {
-  ::paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>()
-      ->SendComplete();
+void Executor::BeginPass() {
+  ::paddle::operators::distributed::RPCClient::GetInstance<
+      ::paddle::operators::distributed::GRPCClient>()
+      ->SendBeginPass();
+}
+
+void Executor::EndPass() {
+  ::paddle::operators::distributed::RPCClient::GetInstance<
+      ::paddle::operators::distributed::GRPCClient>()
+      ->SendEndPass();
 }
 #endif
......
@@ -46,9 +46,14 @@ class Executor {
 #ifdef PADDLE_WITH_DISTRIBUTE
   /*
-   * Sending signal to pserver to mark current trainer stop.
+   * Sending signal to pserver to mark current pass started.
    */
-  void Complete();
+  void BeginPass();
+
+  /*
+   * Sending signal to pserver to mark current pass finished.
+   */
+  void EndPass();
 #endif

   /* @Brief
......
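Note: the one-shot Complete() call, which told the pservers that a trainer was permanently done, is replaced by a symmetric BeginPass/EndPass pair, so a trainer now registers at the start of every pass and deregisters at the end; the diff also pins the client type to GRPCClient rather than the RPCCLIENT_T macro. A Python-level usage sketch follows the executor.py hunk further down.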
@@ -35,10 +35,20 @@ void GRPCClient::InitEventLoop() {
   client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
 }

-void GRPCClient::SendComplete() {
+void GRPCClient::SendBeginPass() {
   for (auto& it : channels_) {
-    this->AsyncSendComplete(it.first);
+    VLOG(3) << "send begin pass to: " << it.first;
+    this->AsyncSendBeginPass(it.first);
   }
+  this->Wait();
+}
+
+void GRPCClient::SendEndPass() {
+  for (auto& it : channels_) {
+    VLOG(3) << "send end pass to " << it.first;
+    this->AsyncSendEndPass(it.first);
+  }
+  this->Wait();
 }

 GRPCClient::~GRPCClient() {
@@ -226,19 +236,32 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
   req_count_++;
 }

-void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
+void GRPCClient::AsyncSendBeginPass(const std::string& ep, int64_t time_out) {
   const auto ch = GetChannel(ep);

   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
   s->Prepare(time_out);

   sendrecv::VariableMessage req;
-  req.set_varname(COMPLETE_MESSAGE);
+  req.set_varname(BEGIN_PASS_MESSAGE);
   auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
   rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
   req_count_++;
 }

+void GRPCClient::AsyncSendEndPass(const std::string& ep, int64_t time_out) {
+  const auto ch = GetChannel(ep);
+
+  FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
+  s->Prepare(time_out);
+
+  sendrecv::VariableMessage req;
+  req.set_varname(END_PASS_MESSAGE);
+  auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
+  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
+  req_count_++;
+}
+
 void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
                                        const std::string& dir,
                                        int64_t time_out) {
......
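Note the deliberate asymmetry above: AsyncSendBeginPass goes out on the Send path (BatchBarrierProcessor, AsyncSendVariable), while AsyncSendEndPass rides the Get path (FetchBarrierProcessor, AsyncGetVariable). That pairing matches the server side below, where RequestSendHandler consumes BEGIN_PASS_MESSAGE and RequestGetHandler consumes END_PASS_MESSAGE. In both SendBeginPass and SendEndPass the client broadcasts to every cached channel, each async call bumps req_count_, and Wait() blocks until the outstanding requests have drained.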
@@ -77,11 +77,12 @@ class BaseProcessor {
     context_.reset(new grpc::ClientContext());
     var_h_ = var_info;
     context_->set_wait_for_ready(true);
-    std::chrono::system_clock::time_point deadline =
-        std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
-    context_->set_deadline(deadline);
+    if (time_out) {
+      std::chrono::system_clock::time_point deadline =
+          std::chrono::system_clock::now() +
+          std::chrono::milliseconds(time_out);
+      context_->set_deadline(deadline);
+    }
   }

   virtual void Prepare(int64_t time_out) {
@@ -214,9 +215,17 @@ class GRPCClient : public RPCClient {
   void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
                              int64_t time_out = FLAGS_rpc_deadline) override;

+  void AsyncSendBeginPass(const std::string& ep,
+                          int64_t time_out = FLAGS_rpc_deadline) override;
+
+  void AsyncSendEndPass(const std::string& ep,
+                        int64_t time_out = FLAGS_rpc_deadline) override;
+
   void Wait() override;

-  void SendComplete() override;
+  void SendBeginPass() override;
+
+  void SendEndPass() override;

 protected:
   void InitImpl() override;
@@ -227,9 +236,6 @@ class GRPCClient : public RPCClient {
   void Proceed();

-  void AsyncSendComplete(const std::string& ep,
-                         int64_t time_out = FLAGS_rpc_deadline);
-
   std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);

 private:
......
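Also worth calling out: the new if (time_out) guard in BaseProcessor means a time_out of 0 now results in no gRPC deadline at all, instead of a deadline of now() that would fail every call almost immediately with DEADLINE_EXCEEDED; presumably this is what lets the pass-barrier RPCs block for as long as a pass takes.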
@@ -37,11 +37,14 @@ constexpr char kRequestSend[] = "RequestSend";
 constexpr char kRequestGet[] = "RequestGet";
 constexpr char kRequestPrefetch[] = "RequestPrefetch";
 constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
+constexpr char kRequestPassBarrier[] = "RequestPassBarrier";

 #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
 #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
 #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
 #define COMPLETE_MESSAGE "COMPLETE@RECV"
+#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
+#define END_PASS_MESSAGE "END_PASS@RECV"

 #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
 #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
@@ -55,14 +55,14 @@ bool RequestSendHandler::Handle(const std::string& varname,
   if (varname == BATCH_BARRIER_MESSAGE) {
     VLOG(3) << "sync: recv batch barrier message";
     rpc_server_->IncreaseBatchBarrier(kRequestSend);
-  } else if (varname == COMPLETE_MESSAGE) {
-    VLOG(3) << "sync: recv complete message";
-    rpc_server_->DecreaseClientNum();
+  } else if (varname == BEGIN_PASS_MESSAGE) {
+    VLOG(3) << "sync: recv begin pass message";
+    rpc_server_->WaitCond(kRequestSend);
+    rpc_server_->BeginPass();
   } else {
     VLOG(3) << "sync: received var_name: " << varname;
-    if (sync_mode_) {
-      rpc_server_->WaitCond(kRequestSend);
-    }
+    rpc_server_->WaitCond(kRequestSend);
+    VLOG(3) << "sync: processing received var: " << varname;

     if (invar == nullptr) {
       LOG(ERROR) << "sync: Can not find server side var: " << varname;
@@ -91,21 +91,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
                                framework::Variable** outvar,
                                const std::string& out_var_name) {
   VLOG(4) << "RequestGetHandler:" << varname;
-
-  if (varname != FETCH_BARRIER_MESSAGE) {
-    if (sync_mode_) {
-      rpc_server_->WaitCond(kRequestGet);
-    }
-    *outvar = scope_->FindVar(varname);
-    return true;
-  }
-
-  // FETCH_BARRIER_MESSAGE
-  if (sync_mode_) {
-    VLOG(3) << "sync: recv fetch barrier message";
-    rpc_server_->IncreaseBatchBarrier(kRequestGet);
-  }
+  if (sync_mode_) {
+    if (varname == FETCH_BARRIER_MESSAGE) {
+      VLOG(3) << "sync: recv fetch barrier message";
+      rpc_server_->IncreaseBatchBarrier(kRequestGet);
+    } else if (varname == END_PASS_MESSAGE) {
+      rpc_server_->EndPass();
+    } else {
+      rpc_server_->WaitCond(kRequestGet);
+      *outvar = scope_->FindVar(varname);
+    }
+  } else {
+    if (varname != FETCH_BARRIER_MESSAGE && varname != END_PASS_MESSAGE) {
+      *outvar = scope_->FindVar(varname);
+    }
+  }
   return true;
 }
......
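The RequestGetHandler rewrite splits cleanly on sync_mode_: in sync mode, FETCH_BARRIER_MESSAGE increases the kRequestGet batch barrier, END_PASS_MESSAGE calls rpc_server_->EndPass(), and any real variable name first waits on the kRequestGet condition before being fetched from the scope; in async mode the two control messages are simply ignored and only real variable names are looked up. On the Send path, BEGIN_PASS_MESSAGE likewise waits on the kRequestSend condition before calling rpc_server_->BeginPass().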
@@ -60,10 +60,17 @@ class RPCClient {
                                      const std::string& dir,
                                      int64_t time_out = FLAGS_rpc_deadline) = 0;

-  // SendComplete tells all the server that current trainer have no more data
-  // to train, so that the pserver can reduce it's barrier count, and continue
-  // to train with other trainers.
-  virtual void SendComplete() = 0;
+  virtual void AsyncSendBeginPass(const std::string& ep,
+                                  int64_t time_out = FLAGS_rpc_deadline) = 0;
+
+  virtual void AsyncSendEndPass(const std::string& ep,
+                                int64_t time_out = FLAGS_rpc_deadline) = 0;
+
+  // BeginePass/EndPass tells all the pserver that start/end a pass, so that
+  // the pserver can increase/reduce it's barrier count, and continue to train
+  // with other trainers.
+  virtual void SendBeginPass() = 0;
+  virtual void SendEndPass() = 0;

   virtual void Wait() = 0;
......
@@ -44,7 +44,8 @@ void RPCServer::SavePort() const {
 void RPCServer::WaitBarrier(const std::string& rpc_name) {
   std::unique_lock<std::mutex> lock(this->mutex_);
   barrier_cond_.wait(lock, [this, &rpc_name] {
-    return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
+    return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
+            exit_flag_.load());
   });

   VLOG(3) << "batch_barrier_: " << rpc_name << " "
@@ -63,10 +64,25 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
   }
 }

-void RPCServer::DecreaseClientNum() {
+void RPCServer::BeginPass() {
+  VLOG(4) << "RPCServer begin increase pass barrier";
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    client_num_++;
+    VLOG(4) << "increase client_num to: " << client_num_;
+  }
+  barrier_cond_.notify_all();
+}
+
+void RPCServer::EndPass() {
+  VLOG(4) << "RPCServer begin increase pass barrier";
   {
     std::unique_lock<std::mutex> lock(mutex_);
     client_num_--;
+    VLOG(4) << "decrease client_num to: " << client_num_;
+    if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
+      barrier_counter_[kRequestGet]--;
+    }
   }
   barrier_cond_.notify_all();
 }
......
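The barrier arithmetic above is the core of the change: DecreaseClientNum() permanently shrank client_num_, whereas BeginPass/EndPass scope the count to a single pass, and WaitBarrier's predicate was tightened from >= to == with a client_num_ != 0 guard so the barrier cannot release while no trainer is registered. A minimal, runnable Python sketch of that counting scheme (a hypothetical toy class, not Paddle code):

    import threading

    class PassBarrier(object):
        """Toy model of the pass-aware batch barrier in RPCServer."""

        def __init__(self):
            self.cond = threading.Condition()
            self.client_num = 0       # trainers currently inside a pass
            self.barrier_counter = 0  # trainers that have reached the barrier

        def begin_pass(self):         # analogue of RPCServer::BeginPass
            with self.cond:
                self.client_num += 1
                self.cond.notify_all()

        def end_pass(self):           # analogue of RPCServer::EndPass
            with self.cond:
                self.client_num -= 1
                self.cond.notify_all()

        def increase_batch_barrier(self):
            with self.cond:
                self.barrier_counter += 1
                self.cond.notify_all()

        def wait_barrier(self):       # analogue of RPCServer::WaitBarrier
            with self.cond:
                # Release only when every registered trainer has arrived
                # and at least one trainer is registered (client_num != 0).
                self.cond.wait_for(
                    lambda: self.barrier_counter == self.client_num
                    and self.client_num != 0)

The equality test is deliberately stricter than the old >=: a counter that no longer matches a changed client count now blocks instead of silently releasing.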
@@ -43,6 +43,9 @@ class RPCServer {
   bool IsExit() { return exit_flag_.load(); }
   int GetSelectedPort() const { return selected_port_; }
+
+  int GetClientNum() const;
+
   void SavePort() const;

   // RegisterRPC, register the rpc method name to a handler
@@ -60,7 +63,10 @@ class RPCServer {
   void SetCond(const std::string& rpc_name);
   void WaitCond(const std::string& rpc_name);
   void IncreaseBatchBarrier(const std::string rpc_name);
-  void DecreaseClientNum();
+
+  void BeginPass();
+  void EndPass();
+
   void ResetBarrierCounter();

 protected:
......
@@ -28,10 +28,10 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
                    "Output(Out) of SqueezeOp should not be null.");

     const auto &x_dims = ctx->GetInputDim("X");
-    // Check input tensor dims (<9).
-    PADDLE_ENFORCE(x_dims.size() <= 9,
+    // Check input tensor dims (<6) Eigen limit.
+    PADDLE_ENFORCE(x_dims.size() <= 6,
                    "Invalid dimnesions, dynamic dimensions must have "
-                   "between [1, 9] dimensions.");
+                   "between [1, 6] dimensions (Eigen limit).");

     const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
     for (int a : axes) {
@@ -41,7 +41,6 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
     auto out_dims = GetOutputShape(axes, x_dims);
     ctx->SetOutputDim("Out", out_dims);
-    // TODO(chenweihang): This share option is necessary?
     if (x_dims[0] == out_dims[0]) {
       // Only pass LoD when the first dimension of output and Input(X)
       // are the same.
......
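The shape arithmetic the updated tests rely on (explicit axes, including negative indices) matches numpy's squeeze; a small illustration of the new 6-D limit and axes=(1, -1), using numpy only:

    import numpy as np

    x = np.random.random((3, 1, 5, 1, 4, 1)).astype("float32")  # 6-D: at the new Eigen limit

    # axes (1, -1) remove the size-1 dims at positions 1 and 5;
    # the size-1 dim at position 3 is not listed, so it survives.
    out = np.squeeze(x, axis=(1, -1))
    assert out.shape == (3, 5, 1, 4)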
@@ -493,7 +493,8 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<framework::Executor>(m, "Executor")
       .def(py::init<const platform::Place &>())
 #ifdef PADDLE_WITH_DISTRIBUTE
-      .def("complete", &Executor::Complete)
+      .def("begin_pass", &Executor::BeginPass)
+      .def("end_pass", &Executor::EndPass)
 #endif
       .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
                      int block_id, bool create_local_scope, bool create_vars) {
......
@@ -2,7 +2,7 @@
 %include "std_string.i"
 %{
 #define SWIG_FILE_WITH_INIT
-#include "api/PaddleAPI.h"
+#include "legacy/api/PaddleAPI.h"
 %}

 %include "exception.i"
@@ -199,4 +199,4 @@ namespace std {
 %ignore OptimizationConfigPrivate;
 %ignore ParameterTraverseCallbackPrivate;
 %include "utils/GlobalConstants.h"
-%include "api/PaddleAPI.h"
+%include "legacy/api/PaddleAPI.h"
@@ -348,6 +348,12 @@ class Executor(object):
         ]
         return outs

+    def begin_pass(self):
+        self.executor.begin_pass()
+
+    def end_pass(self):
+        self.executor.end_pass()
+
     def run(self,
             program=None,
             feed=None,
......
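With the pybind binding and the executor.py wrappers above, a distributed trainer can bracket every pass explicitly. A usage sketch under stated assumptions: fluid-era API, a distributed build (the methods only exist under PADDLE_WITH_DISTRIBUTE), and train_reader/feeder/loss as placeholders for your own data pipeline and fetch targets:

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    for pass_id in range(10):
        exe.begin_pass()                 # register this trainer with every pserver
        for batch in train_reader():     # train_reader: placeholder data source
            exe.run(fluid.default_main_program(),
                    feed=feeder.feed(batch),   # feeder: placeholder DataFeeder
                    fetch_list=[loss])         # loss: placeholder fetch target
        exe.end_pass()                   # deregister and release the pass barrier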
@@ -27,7 +27,7 @@ class TestSqueezeOp1(OpTest):
         self.op_type = "squeeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": False}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

     def test_check_output(self):
@@ -46,7 +46,7 @@ class TestSqueezeOp2(OpTest):
         self.op_type = "squeeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": False}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

     def test_check_output(self):
@@ -65,7 +65,7 @@ class TestSqueezeOp3(OpTest):
         self.op_type = "squeeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": False}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

     def test_check_output(self):
@@ -78,13 +78,13 @@ class TestSqueezeOp3(OpTest):
 # Correct: Just part of axes be squeezed.
 class TestSqueezeOp4(OpTest):
     def setUp(self):
-        ori_shape = (1, 3, 1, 5, 1, 4, 1)
-        axes = (2, 6)
-        new_shape = (1, 3, 5, 1, 4)
+        ori_shape = (3, 1, 5, 1, 4, 1)
+        axes = (1, -1)
+        new_shape = (3, 5, 1, 4)
         self.op_type = "squeeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": False}
+        self.attrs = {"axes": axes, "inplace": False}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

     def test_check_output(self):
@@ -122,7 +122,7 @@ class TestSqueezeOpInplace2(OpTest):
         self.op_type = "squeeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": True}
+        self.attrs = {"axes": axes, "inplace": True}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

     def test_check_output(self):
@@ -141,7 +141,7 @@ class TestSqueezeOpInplace3(OpTest):
         self.op_type = "squeeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": True}
+        self.attrs = {"axes": axes, "inplace": True}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

     def test_check_output(self):
@@ -154,13 +154,13 @@ class TestSqueezeOpInplace3(OpTest):
 # Correct: Inpalce. Just part of axes be squeezed.
 class TestSqueezeOpInplace4(OpTest):
     def setUp(self):
-        ori_shape = (1, 3, 1, 5, 1, 4, 1)
-        axes = (2, 6)
-        new_shape = (1, 3, 5, 1, 4)
+        ori_shape = (3, 1, 5, 1, 4, 1)
+        axes = (1, -1)
+        new_shape = (3, 5, 1, 4)
         self.op_type = "squeeze"
         self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
-        self.attrs = {"axes": axes, "inpalce": True}
+        self.attrs = {"axes": axes, "inplace": True}
         self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

     def test_check_output(self):
......
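Two details in the test updates are easy to miss: the attribute key had been misspelled ("inpalce"), so the inplace flag these tests meant to toggle most likely never reached the op; and the part-of-axes cases were trimmed from 7-D to 6-D inputs, both to respect the new rank-6 Eigen limit enforced in SqueezeOpInferShape and to cover a negative axis (-1).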