提交 9de096bd 编写于 作者: Y Yancey1989

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fix_cpu_pe

......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
......@@ -53,8 +54,14 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
}
}
std::vector<framework::LoDTensor> fetch_data;
std::exception_ptr eptr;
try {
fetch_data = underlying_executor_->Run(fetch_tensors);
} catch (...) {
eptr = std::current_exception();
}
auto fetch_data = underlying_executor_->Run(fetch_tensors);
drop_scope_counter_ += 1;
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
......@@ -69,7 +76,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
scope->DeleteScope(local_scope);
}
}
return fetch_data;
if (eptr) {
std::rethrow_exception(eptr);
} else {
return fetch_data;
}
}
} // namespace details
} // namespace framework
......
......@@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
set.clear();
};
// Clean run context
run_op_futures_.clear();
exception_.reset();
// Step 3. Execution
while (!pending_vars.empty()) {
// 1. Run All Ready ops
......@@ -96,16 +100,19 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
if (timeout) {
std::lock_guard<std::mutex> l(exception_mu_);
std::unique_lock<std::mutex> l(exception_mu_);
if (exception_) {
l.unlock();
for (auto &run_op_future : run_op_futures_) {
run_op_future.wait();
}
l.lock();
std::exception *exp = exception_.get();
if (dynamic_cast<platform::EOFException *>(exp)) {
auto e = *static_cast<platform::EOFException *>(exp);
exception_.reset();
throw e;
} else if (dynamic_cast<platform::EnforceNotMet *>(exp)) {
auto e = *static_cast<platform::EnforceNotMet *>(exp);
exception_.reset();
throw e;
} else {
LOG(FATAL) << "Unknown exception.";
......@@ -222,7 +229,7 @@ void ThreadedSSAGraphExecutor::RunOp(
}
};
if (pool_) {
pool_->enqueue(op_run);
run_op_futures_.emplace_back(pool_->enqueue(op_run));
} else {
op_run();
}
......
......@@ -15,6 +15,7 @@
#pragma once
#include <deque>
#include <list>
#include <string>
#include <unordered_set>
#include <utility>
......@@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
private:
ExecutionStrategy strategy_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
};
} // namespace details
......
......@@ -259,12 +259,15 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>
namespace paddle {
namespace operators {
/**
* Organize the classes into a binary tree. At each node, a sigmoid function
* is used to calculate the probability of belonging to the right branch.
* This idea is from "F. Morin, Y. Bengio (AISTATS 05):
* Hierarchical Probabilistic Neural Network Language Model."
*
* Here we uses a simple way of making the binary tree.
* Assuming the number of classes C = 6,
* The classes are organized as a binary tree in the following way:
*
* @code{.py}
* *-*-*- 2
* | | |- 3
* | |
* | |-*- 4
* | |- 5
* |
* |-*- 0
* |- 1
* @endcode
*
* where * indicates an internal node, and each leaf node represents a class.
* - Node 0 ... C-2 are internal nodes.
* - Node C-1 ... 2C-2 are leaf nodes.
* - Class c is represented by leaf node \f$c+C-1\f$.
*
* We assign an id for each node:
* - the id of root be 0.
* - the left child of a node i is 2*i+1.
* - the right child of a node i is 2*i+2.
*
* It's easy to see that:
* - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
* - the j-th level ancestor of node i is
* \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
* - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
*
*/
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
"Output(PreOut) should not be null.");
const int64_t batch_size = ctx->GetInputDim("X")[0];
std::vector<int64_t> output_shape({batch_size, 1});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
template <typename AttrType>
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, required) The input tensor with shape [N, D], "
"where N is the size of mini-batch, and D is the feature size.");
AddInput("W",
"(Tensor, required), The parameters of hierarchical "
"sigmoid operator, each of them is a 2-D tensor, the shape is"
"[num_classes - 1, D].");
AddInput("Label",
"(Tensor, required), The labels of training data. It's a"
"tensor with shape [N, 1].");
AddInput("Bias",
"(Tensor, optional), The bias is a tensor with shape"
"[1, num_classes - 1].");
AddOutput("Out",
"(Tensor, required) The output of hierarchical sigmoid operator."
"The shape is [N, 1].");
AddOutput("PreOut",
"(Tensor, required) A intermedia 2-D tensor with shape "
"[batch_size, code_length], where code_length represents the "
"maximum path length from root to leaf nodes.")
.AsIntermediate();
AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
.SetDefault(2);
AddComment(R"DOC(
The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
}
};
class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PreOut"),
"Input(Preout) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
"Output(W@Grad should not be null.)");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
float>,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* label = ctx.Input<framework::Tensor>("Label");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* out = ctx.Output<framework::Tensor>("Out");
auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
int64_t code_length = math::FindLastSet(num_classes - 1);
int64_t batch_size = in->dims()[0];
framework::Tensor sum;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* pre_out_data = pre_out->mutable_data<T>(
framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
// Not all class(leaf) nodes' path lengths equal code_length, thus init as
// 0s can avoid out of path's loss.
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
math::RowwiseSum<DeviceContext, T> row_sum;
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
auto sum_mat = EigenMatrix<T>::From(sum);
out->mutable_data<T>(ctx.GetPlace());
auto out_mat = framework::EigenVector<T>::Flatten(*out);
if (bias) {
bit_code.Add(pre_out, *bias);
}
bit_code.Mul(pre_out, *w, *in);
// clip to [-40, 40]
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
bit_code.Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(dev_ctx, *pre_out, &sum);
// TODO(guosheng): Subtract the out of path's loss, since not all
// class(leaf) nodes' path lengths equal code_length. But it won't break the
// gradient check since both have the out of path's loss and will cancel out
// each other.
out_mat.device(place) = sum_mat + out_mat;
}
};
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* w = ctx.Input<framework::Tensor>("W");
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
auto* bias_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
auto* label = ctx.Input<framework::Tensor>("Label");
auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor pre_out_grad;
pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
in_grad->mutable_data<T>(ctx.GetPlace());
w_grad->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, in_grad, static_cast<T>(0.0));
zero(dev_ctx, w_grad, static_cast<T>(0.0));
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
// softrelu derivative
pre_out_grad_mat.device(place) =
static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b)
pre_out_grad_mat.device(place) =
pre_out_grad_mat * out_grad_mat.broadcast(bcast);
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code.AddGrad(pre_out_grad, bias_grad);
}
bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
bit_code.MulGradError(pre_out_grad, *w, in_grad);
}
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,7 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function)
math_library(sequence_scale)
math_library(softmax DEPS math_function)
math_library(matrix_bit_code)
math_library(unpooling)
math_library(vol2col)
......
......@@ -155,7 +155,7 @@ class RowwiseSum<platform::CPUDeviceContext, T> {
PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), size);
PADDLE_ENFORCE_EQ(out->numel(), height);
T* out_buf = out->mutable_data<T>(out->place());
const T* in_buf = input.data<T>();
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
namespace paddle {
namespace operators {
namespace math {
template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
const framework::Tensor& vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat->dims()[0];
size_t width = tmat->dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
tmat->data<T>()[i * width + j] += vec.data<T>()[index];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
framework::Tensor* vec) {
SimpleCodeTable code_table(num_classes_);
size_t batch_size = tmat.dims()[0];
size_t width = tmat.dims()[1];
for (size_t i = 0; i < batch_size; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
vec->data<T>()[index] += tmat.data<T>()[i * width + j];
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
framework::Tensor* sum, T scale_sum) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t o_width = tmat.dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
T sm = static_cast<T>(0.0);
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
// calc_bit starts from right most bit, while data in tmat[i] is in the
// reverse order.
sm += tmat.data<T>()[i * o_width + j];
}
}
sum->data<T>()[i] = scale_sum * sm;
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
const framework::Tensor& weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t tmat_width = tmat->dims()[1];
size_t input_width = input.dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat->data<T>();
auto weight_value = weight.data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
T sum = static_cast<T>(0.0);
for (size_t k = 0; k < input_width; ++k) {
sum += weight_value[weight_width * index + k] *
input_value[input_width * i + k];
}
tmat_value[i * tmat_width + j] += sum;
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
framework::Tensor* weight,
const framework::Tensor& input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t input_width = input.dims()[1];
size_t tmat_width = tmat.dims()[1];
size_t weight_width = weight->dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight->data<T>();
auto input_value = input.data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
weight_value[weight_width * index + k] +=
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight,
framework::Tensor* input) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat.dims()[0];
size_t tmat_width = tmat.dims()[1];
size_t input_width = input->dims()[1];
size_t weight_width = weight.dims()[1];
auto tmat_value = tmat.data<T>();
auto weight_value = weight.data<T>();
auto input_value = input->data<T>();
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
size_t index = code.calc_index(j);
for (size_t k = 0; k < input_width; ++k) {
input_value[input_width * i + k] +=
tmat_value[i * tmat_width + j] *
weight_value[weight_width * index + k];
}
}
}
}
template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
SimpleCodeTable code_table(num_classes_);
size_t num_samples = tmat->dims()[0];
size_t o_width = tmat->dims()[1];
for (size_t i = 0; i < num_samples; ++i) {
auto code = code_table(static_cast<size_t>(ids_[i]));
int code_length = code.get_length();
for (int j = 0; j < code_length; ++j) {
if (code.calc_bit(j)) {
tmat->data<T>()[i * o_width + j] -= 1;
}
}
}
}
template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/**
* SimpleCodeTable class should support 3 functions:
*
* size_t size()
* return the number of ids
*
* int get_max_code_length()
* return the maximal code length
*
* SimpleCode operator()(size_t i)
* return the i-th code. Code class is descriebed below.
*
* SimpleCode class should support 3 functions:
*
* int get_length()
* return the length of the code
*
* size_t cal_index(int bit)
* bit ranges from 0 to get_length() - 1
* return the index for the (1+bit) level parent
*
* bool calc_bit(int bit)
* return true if the bit level parent is the right child of (1+bit) level
* parent
*
*/
/**
* return the 1-based index of the highest bit set
*
* for x > 0:
* \f[
* FindLastSet(x) = 1 + \floor*{\log_{2}x}
* \f]
*/
inline constexpr size_t FindLastSet(size_t x) {
return std::is_same<size_t, unsigned int>::value
? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
: (std::is_same<size_t, unsigned long>::value // NOLINT
? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
: (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
struct SimpleCode {
SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
/**
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c
* is `c + num_classes` and all siblings can get the same weight indice using
* prefixes.
* Weight index is the prefixes of encoding, thus leave out the right most
* bit in calc_index.
* Binary classification path is the suffixes of encoding, thus leave out the
* left most bit in calc_bit.
*/
inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
inline int get_length() const { return FindLastSet(c_) - 1; }
private:
size_t c_;
};
struct SimpleCodeTable {
explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
SimpleCode operator()(size_t code) const {
return SimpleCode(code, num_classes_);
}
size_t size() const { return num_classes_; }
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
private:
size_t num_classes_;
};
template <typename T>
class MatrixBitCodeFunctor {
public:
explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
: num_classes_(num_classes), ids_(ids) {}
/* For j < code_length
tmat(i, j) += vec(0, index(i, j))
*/
void Add(framework::Tensor* tmat, const framework::Tensor& vec);
/* For j < code_length
vec(0, index(i, j)) += tmat(i, j)
*/
void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
/* For j < code_length
sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
*/
void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
/* For j < code_length
tmat(i, j) -= bit(i, j)
*/
void Sub(framework::Tensor* tmat);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
const framework::Tensor& input);
/* For index(i, j) >= 0:
weight.row(index(i, j)) += tmat(i, j) * input.row(i)
*/
void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
const framework::Tensor& input);
/* For j < code_length
input.row(i) += tmat(i, j) * weight.row(index(i, j))
*/
void MulGradError(const framework::Tensor& tmat,
const framework::Tensor& weight, framework::Tensor* input);
size_t num_classes_;
const int64_t* ids_;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -23,7 +23,7 @@ class BatchReader : public framework::DecoratedReader {
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
bool discard_leftover)
: DecoratedReader(reader),
batch_size_(batch_size),
batch_size_(static_cast<size_t>(batch_size)),
discard_leftover_(discard_leftover) {
buffer_.reserve(batch_size_);
}
......@@ -31,7 +31,7 @@ class BatchReader : public framework::DecoratedReader {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private:
int batch_size_;
size_t batch_size_;
bool discard_leftover_;
std::vector<std::vector<framework::LoDTensor>> buffer_;
};
......@@ -78,7 +78,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
for (size_t i = 0; i < batch_size_; ++i) {
buffer_.push_back(std::vector<framework::LoDTensor>());
reader_->ReadNext(&buffer_.back());
if (buffer_.back().empty()) {
......@@ -95,9 +95,9 @@ void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
// if buffer_ is empty, the 'out' will return as an empty vector.
return;
}
int out_num = buffer_[0].size();
size_t out_num = buffer_[0].size();
out->reserve(out_num);
for (int j = 0; j < out_num; ++j) {
for (size_t j = 0; j < out_num; ++j) {
// Merge shape and check date type
std::type_index batch_type = buffer_[0][j].type();
framework::DDim batch_shape = buffer_[0][j].dims();
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class SqueezeOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SqueezeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SqueezeOp should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimnesions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit).");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
"The squeeze axis should be less than input "
"tensor's rank.");
}
auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
}
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
const framework::DDim &in_dims) {
size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false};
// Determines number of dimensions of output tensor after squeeze.
// Mark and count the dimensions need to be squeezed
if (num_squeeze_dims == 0) {
for (int idx = 0; idx < in_dims.size(); ++idx) {
if (in_dims[idx] == 1) {
should_squeeze[idx] = true;
++cnt_squeezed_dims;
}
}
} else {
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx];
// Check current index, the upper limit has beed checked in line 36.
PADDLE_ENFORCE(current >= 0,
"Invalid axis, the negative axis is out of range.");
PADDLE_ENFORCE(in_dims[current] == 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
}
should_squeeze[current] = true;
}
}
// Make output dimensions
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
if (!should_squeeze[in_idx]) {
output_shape[out_idx++] = in_dims[in_idx];
}
}
return framework::make_ddim(output_shape);
}
};
class SqueezeOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
attrs["inplace"] = Attr<bool>("inplace");
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor). The input tensor of squeeze operator.");
AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of integers,"
" indicating the dimensions to squeeze.")
.SetDefault({});
AddAttr<bool>("inplace",
"(default: false) Squeeze the source tensor's shape without "
"memory copy. When Attr(inplace) is set true, the output "
"tensor shares memory with Input(X), otherwise, a new output "
"tensor is created, and its data are copied from Input(x).")
.SetDefault(false);
AddComment(R"DOC(
Squeeze Operator.
Remove single-dimensional entries from the shape of a tensor.
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with shape entry not equal to one, an error is raised.
Examples:
Case 1:
Given
X.shape = (1, 3, 1, 5)
and
axes = [0]
we get:
Out.shape = (3, 1, 5)
Case 2:
Given
X.shape = (1, 3, 1, 5)
and
axes = []
we get:
Out.shape = (3, 5)
)DOC");
}
};
class SqueezeGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim(framework::GradVarName("X"),
context->GetInputDim("X"));
context->ShareLoD("X", framework::GradVarName("X"));
}
};
class SqueezeGradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = Attr<bool>("inplace");
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
// Tell linker to use reshape op
USE_OP(reshape);
namespace ops = paddle::operators;
REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops::SqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class UnsqueezeOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UnsqueezeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UnsqueezeOp should not be null.");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
const auto &x_dims = ctx->GetInputDim("X");
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE(x_dims.size() <= 6,
"Invalid dimensions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit)");
auto out_dims = GetOutputShape(axes, x_dims);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
}
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
const framework::DDim &in_dims) {
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
// Validity Check: rank range.
PADDLE_ENFORCE(output_size <= 6,
"The output tensor's rank should be less than 6.");
for (int axis : unsqz_dims) {
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
// Vaildity Check: the axis bound
PADDLE_ENFORCE(
cur >= 0 && cur <= cur_output_size,
"The unsqueeze dims must be within range of current rank.");
// Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
output_shape[i] = 0;
}
}
output_shape[cur] = 1;
// Add the output size.
cur_output_size++;
}
// Make output shape
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
}
}
return framework::make_ddim(output_shape);
}
};
class UnsqueezeOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &axes = Attr<std::vector<int>>("axes");
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
attrs["inplace"] = Attr<bool>("inplace");
// Invoke Reshape op.
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
{{"Out", {Output("Out")}}}, attrs);
reshape_op->Run(scope, place);
}
};
class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor). The input tensor of unsqueeze operator.");
AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator.");
AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of integers,"
" indicating the dimensions to be inserted")
.AddCustomChecker([](const std::vector<int> &axes) {
PADDLE_ENFORCE(!axes.empty(),
"Invalid axes, The unsqueeze axes is empty.");
// Validity Check: axes dims (<6).
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
"Invalid dimensions, dynamic dimensions should be "
"within [1, 6] dimensions (Eigen limit).");
// Validity Check: the range of unsqueeze aixs.
for (int axis : axes) {
PADDLE_ENFORCE(axis < 6,
"Invalid dimensions, input axis should be"
" within [1, 6] dimensions (Eigen limit).");
}
});
AddAttr<bool>(
"inplace",
"(default: false) Unsqueeze the source tensor's shape without "
"memory copy. When Attr(inplace) is set true, the output "
"tensor shares memory with Input(X), otherwise, a new output "
"tensor is created, and its data are copied from Input(x).")
.SetDefault(false);
AddComment(R"DOC(
Unsqueeze Operator.
Insert single-dimensional entries to the shape of a tensor.
Takes one required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
For example:
Given a tensor such that tensor with shape [3, 4, 5],
then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
)DOC");
}
};
class UnsqueezeGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
};
class UnsqueezeGradOp : public framework::OperatorBase {
public:
using OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dx_name = Output(framework::GradVarName("X"));
auto dout_name = Input(framework::GradVarName("Out"));
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = Attr<bool>("inplace");
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
attrs);
reshape_op->Run(scope, place);
}
};
} // namespace operators
} // namespace paddle
// Tell linker to use reshape op.
USE_OP(reshape);
namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
ops::UnsqueezeOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeGradInferShape);
......@@ -66,6 +66,14 @@ bool IsCompiledWithCUDA() {
#endif
}
bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DIST
return true;
#else
return false;
#endif
}
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle");
......@@ -508,6 +516,7 @@ All parameter, weight, gradient are variables in Paddle.
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_dist", IsCompiledWithDIST);
#ifdef PADDLE_WITH_CUDA
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 53 support float16
......
......@@ -92,8 +92,15 @@ install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR}
DESTINATION opt/paddle/share/wheels
)
find_program(PATCHELF_EXECUTABLE patchelf)
if(NOT PATCHELF_EXECUTABLE)
message(FATAL_ERROR "patchelf not found, please install it.\n"
"For Ubuntu, the command is: apt-get install -y patchelf.")
endif()
if(APPLE)
find_program(INSTALL_NAME_TOOL_EXECUTABLE install_name_tool)
if(NOT INSTALL_NAME_TOOL_EXECUTABLE)
message(FATAL_ERROR "install_name_tool not found, please check.\n")
endif()
else(APPLE)
find_program(PATCHELF_EXECUTABLE patchelf)
if(NOT PATCHELF_EXECUTABLE)
message(FATAL_ERROR "patchelf not found, please install it.\n"
"For Ubuntu, the command is: apt-get install -y patchelf.")
endif()
endif(APPLE)
......@@ -121,6 +121,9 @@ def __bootstrap__():
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
'init_allocated_mem'
]
if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline')
if core.is_compiled_with_cuda():
read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic'
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import sys
__all__ = ['deprecated']
def deprecated(since, instead, extra_message=""):
def decorator(func):
err_msg = "API {0} is deprecated since {1}. Please use {2} instead.".format(
func.__name__, since, instead)
if len(extra_message) != 0:
err_msg += "\n"
err_msg += extra_message
@functools.wraps(func)
def wrapper(*args, **kwargs):
print >> sys.stderr, err_msg
return func(*args, **kwargs)
wrapper.__doc__ += "\n "
wrapper.__doc__ += err_msg
return wrapper
return decorator
......@@ -18,10 +18,7 @@ import collections
import copy
import unique_name
__all__ = [
'append_backward',
'calc_gradient',
]
__all__ = ['append_backward']
def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
......
此差异已折叠。
......@@ -18,10 +18,12 @@ All util layers.
from layer_function_generator import autodoc
from ..framework import unique_name
from ..layer_helper import LayerHelper
from ..annotations import deprecated
__all__ = ['get_places']
__all__ = []
@deprecated(since='0.15.0', instead="ParallelExecutor")
@autodoc()
def get_places(device_count=None, device_type=None):
helper = LayerHelper('get_places', **locals())
......
......@@ -85,6 +85,7 @@ __all__ = [
'transpose',
'im2sequence',
'nce',
'hsigmoid',
'beam_search',
'row_conv',
'multiplex',
......@@ -3871,6 +3872,74 @@ def nce(input,
return cost / (num_neg_samples + 1)
def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
"""
The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a
complete binary tree, each leaf node represents a class(a word) and each
internal node acts as a binary classifier. For each word there's a unique
path from root to it's leaf node, hsigmoid calculate the cost for each
internal node on the path, and sum them to get a total cost. hsigmoid can
achive a acceleration from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
represents the size of word dict.
Refer to `Hierarchical Probabilistic Neural Network Language Model
<http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_
Args:
input (Variable): The input tensor variable with shape
:math:`[N \\times D]`, where :math:`N` is the size of mini-batch,
and :math:`D` is the feature size.
label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2.
param_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for learnable parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for the bias of this layer. If it is set to False, no
bias will be applied.
Returns:
Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1]
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[2], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
out = fluid.layers.hsigmoid(input=x, label=y, num_classes=6)
"""
helper = LayerHelper('hierarchical_sigmoid', **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
pre_out = helper.create_tmp_variable(dtype)
dim = input.shape[1]
if num_classes < 2:
raise ValueError("num_classes must not be less than 2.")
weights = helper.create_parameter(
attr=helper.param_attr,
shape=[num_classes - 1, dim],
is_bias=False,
dtype=input.dtype)
inputs = {"X": input, "W": weights, "Label": label}
if helper.bias_attr:
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=[1, num_classes - 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = bias
helper.append_op(
type="hierarchical_sigmoid",
inputs=inputs,
outputs={"Out": out,
"PreOut": pre_out},
attrs={"num_classes": num_classes})
return out
def transpose(x, perm, name=None):
"""
Permute the dimensions of `input` according to `perm`.
......
......@@ -29,7 +29,7 @@ __all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'Optimizer', 'RMSPropOptimizer'
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'RMSPropOptimizer'
]
......@@ -67,7 +67,7 @@ class Optimizer(object):
self._LARS_weight_decay = LARS_weight_decay
def _create_global_learning_rate(self):
lr = self.global_learning_rate()
lr = self._global_learning_rate()
if isinstance(lr, framework.Variable):
return
......@@ -86,7 +86,7 @@ class Optimizer(object):
dtype='float32' if self._dtype == None else self._dtype,
persistable=True)
def global_learning_rate(self, program=None):
def _global_learning_rate(self, program=None):
"""
get global decayed learning rate
:return:
......@@ -110,9 +110,9 @@ class Optimizer(object):
return param_lr
else:
if param_lr == 1.0:
return self.global_learning_rate()
return self._global_learning_rate()
else:
return self.global_learning_rate() * param_lr
return self._global_learning_rate() * param_lr
def _create_accumulators(self, block, parameters):
"""Create all accumulators needed by the parameters
......@@ -185,10 +185,10 @@ class Optimizer(object):
format(name, param.name))
return self._accumulators[name][param.name]
def create_optimization_pass(self,
parameters_and_grads,
loss,
startup_program=None):
def _create_optimization_pass(self,
parameters_and_grads,
loss,
startup_program=None):
"""Add optimization operators to update gradients to variables.
Args:
......@@ -221,7 +221,7 @@ class Optimizer(object):
self._create_global_learning_rate()
if self._LARS_weight_decay > 0.0:
layers.append_LARS(parameters_and_grads,
self.global_learning_rate(),
self._global_learning_rate(),
self._LARS_weight_decay)
optimize_ops = []
......@@ -262,8 +262,8 @@ class Optimizer(object):
params_grads = append_regularization_ops(params_grads,
self.regularization)
optimize_ops = self.create_optimization_pass(params_grads, loss,
startup_program)
optimize_ops = self._create_optimization_pass(params_grads, loss,
startup_program)
return optimize_ops, params_grads
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid.layers.device import get_places
import unittest
import paddle.fluid as fluid
import paddle
......@@ -144,7 +144,7 @@ def train(word_dict,
cost, acc_out, prediction = net_method(
data, label, input_dim=dict_dim, class_dim=class_dim)
else:
places = fluid.layers.get_places()
places = get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
cost, acc, _ = net_method(
......
......@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import paddle.fluid as fluid
import paddle
import sys
import numpy
import unittest
import math
import sys
import os
import sys
import unittest
import numpy
import paddle
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
BATCH_SIZE = 64
......@@ -76,7 +78,7 @@ def train(nn_type,
net_conf = conv_net
if parallel:
places = fluid.layers.get_places()
places = get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
img_ = pd.read_input(img)
......
......@@ -14,6 +14,7 @@
import paddle
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
import unittest
import os
import numpy as np
......@@ -80,7 +81,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True):
avg_cost, predict_word = __network__(
[first_word, second_word, third_word, forth_word, next_word])
else:
places = fluid.layers.get_places()
places = get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
avg_cost, predict_word = __network__(
......
......@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
import math
import sys
import paddle
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
# need to fix random seed and training data to compare the loss
# value accurately calculated by the default and the memory optimization
# version.
......@@ -34,7 +35,7 @@ if fluid.core.is_compiled_with_cuda():
use_nccl = False
place = fluid.CUDAPlace(0)
places = fluid.layers.get_places(device_count=0, device_type=device_type)
places = get_places(device_count=0, device_type=device_type)
pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
with pd.do():
x_ = pd.read_input(x)
......
......@@ -16,8 +16,6 @@ import unittest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer
from paddle.fluid.backward import calc_gradient
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import os
import tempfile
class TestCheckpoint(unittest.TestCase):
def setUp(self):
self.dirname = tempfile.mktemp()
self.max_num_checkpoints = 3
self.epoch_interval = 1
self.step_interval = 1
self.trainer_id = 0
self.chief = self.trainer_id == 0
self.place = fluid.CPUPlace()
self.epoch_id = 100
self.step_id = 20
def test_checkpoint(self):
self.save_checkpoint()
serial = fluid.io.get_latest_checkpoint_serial(self.dirname)
self.assertTrue(serial >= 0)
trainer_args = ["epoch_id", "step_id"]
epoch_id, step_id = fluid.io.load_trainer_args(
self.dirname, serial, self.trainer_id, trainer_args)
self.assertEqual(self.step_id, int(step_id))
self.assertEqual(self.epoch_id, int(epoch_id))
program = fluid.Program()
with fluid.program_guard(program):
exe = fluid.Executor(self.place)
fluid.io.load_checkpoint(exe, self.dirname, serial, program)
fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
self.assertFalse(os.path.isdir(self.dirname))
def save_checkpoint(self):
config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
self.epoch_interval, self.step_interval)
trainer_args = {}
trainer_args["epoch_id"] = self.epoch_id
trainer_args["step_id"] = self.step_id
program = fluid.Program()
with fluid.program_guard(program):
program.global_block().create_var(
name="scale_0",
psersistable=True,
dtype="float32",
shape=[32, 32])
exe = fluid.Executor(self.place)
for i in xrange(10):
fluid.io.save_checkpoint(exe, config.checkpoint_dir,
self.trainer_id, trainer_args, program,
config.max_num_checkpoints)
if __name__ == '__main__':
unittest.main()
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
import decorators
import unittest
......@@ -20,7 +21,7 @@ import unittest
class TestGetPlaces(unittest.TestCase):
@decorators.prog_scope()
def test_get_places(self):
places = fluid.layers.get_places()
places = get_places()
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(fluid.default_main_program())
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
def find_latest_set(num):
return 1 + int(math.floor(math.log(num, 2)))
class CodeTable(object):
def __init__(self, num_classes, code):
self.c = num_classes + code
def cal_index(self, bit):
return (self.c >> (bit + 1)) - 1
def get_length(self):
return find_latest_set(self.c) - 1
def cal_bit(self, bit):
return self.c & (1 << bit)
def hsigmoid(x, w, label, bias, num_classes):
batch_size = x.shape[0]
code_length = find_latest_set(num_classes - 1)
code_table = [0 for _ in range(code_length)]
pre_output = np.zeros((batch_size, code_length))
pre_sum = np.zeros((batch_size, 1))
out = np.zeros((batch_size, 1)).astype("float32")
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
pre_output[i][j] += bias[0][idx]
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
for j in range(length):
idx = code_table.cal_index(j)
pre_output[i][j] += np.dot(w[idx], x[i])
# clip[-40.0, 40.0]
pre_output = np.clip(pre_output, -40.0, 40.0)
# out(i, 0) = \sum_j bit(i, j) * preout(i, j)
for i in range(batch_size):
code_table = CodeTable(num_classes, label[i])
length = code_table.get_length()
sum = 0.0
for j in range(length):
if code_table.cal_bit(j):
sum += pre_output[i][j]
out[i] = -1.0 * sum
# soft relu
pre_output = np.log(1 + np.exp(pre_output))
pre_sum = pre_output.sum(1).reshape((batch_size, 1))
out += pre_sum
return pre_output, out
class TestHSigmoidOp(OpTest):
def setUp(self):
self.op_type = "hierarchical_sigmoid"
num_classes = 6
feature_size = 8
batch_size = 4
x = np.random.random((batch_size, feature_size)).astype("float32")
w = np.random.random((num_classes - 1, feature_size)).astype("float32")
label = np.random.randint(0, num_classes, (batch_size, 1))
bias = np.random.random((1, num_classes - 1)).astype("float32")
self.attrs = {'num_classes': num_classes}
self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
pre_output, out = hsigmoid(x, w, label, bias, num_classes)
self.outputs = {'PreOut': pre_output, 'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label'))
if __name__ == '__main__':
unittest.main()
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import paddle.fluid.layers as layers
from paddle.fluid.layers.device import get_places
import paddle.fluid.nets as nets
from paddle.fluid.framework import Program, program_guard, default_main_program
from paddle.fluid.param_attr import ParamAttr
......@@ -173,6 +174,16 @@ class TestBook(unittest.TestCase):
x=dat, label=lbl))
print(str(program))
def test_hsigmoid(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[2], dtype='int64')
self.assertIsNotNone(
layers.hsigmoid(
input=x, label=y, num_classes=2))
print(str(program))
def test_sequence_expand(self):
program = Program()
with program_guard(program):
......@@ -238,7 +249,7 @@ class TestBook(unittest.TestCase):
def test_get_places(self):
program = Program()
with program_guard(program):
x = layers.get_places(device_count=4)
x = get_places(device_count=4)
self.assertIsNotNone(x)
print(str(program))
......
......@@ -97,7 +97,7 @@ class TestMomentumOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(
opts = momentum_optimizer._create_optimization_pass(
params_grads, mul_out, init_program)
self.assertEqual(len(opts), 3)
sgd_op = opts[-1]
......@@ -151,7 +151,7 @@ class TestMomentumOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass(
opts = momentum_optimizer._create_optimization_pass(
params_grads, mul_out, init_program)
self.assertEqual(len(opts), 3)
sgd_op = opts[-1]
......@@ -214,8 +214,8 @@ class TestAdagradOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0)
opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out,
init_program)
opts = adagrad_optimizer._create_optimization_pass(
params_grads, mul_out, init_program)
self.assertEqual(len(opts), 3)
self.assertEqual([op.type for op in opts],
["fill_constant", "elementwise_mul", "adagrad"])
......@@ -278,8 +278,8 @@ class TestAdamOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
opts = adam_optimizer.create_optimization_pass(params_grads, mul_out,
init_program)
opts = adam_optimizer._create_optimization_pass(params_grads, mul_out,
init_program)
self.assertEqual(len(opts), 5)
self.assertEqual(
[op.type for op in opts],
......@@ -345,8 +345,8 @@ class TestAdamaxOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adamax_optimizer.get_accumulators()), 0)
opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out,
init_program)
opts = adamax_optimizer._create_optimization_pass(params_grads, mul_out,
init_program)
self.assertEqual(len(opts), 4)
self.assertEqual(
[op.type for op in opts],
......@@ -409,7 +409,7 @@ class TestDecayedAdagradOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0)
opts = decayed_adagrad_optimizer.create_optimization_pass(
opts = decayed_adagrad_optimizer._create_optimization_pass(
params_grads, mul_out, init_program)
self.assertEqual(len(opts), 3)
self.assertEqual(
......@@ -475,8 +475,8 @@ class TestFtrlOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(ftrl_optimizer.get_accumulators()), 0)
opts = ftrl_optimizer.create_optimization_pass(params_grads, mul_out,
init_program)
opts = ftrl_optimizer._create_optimization_pass(params_grads, mul_out,
init_program)
self.assertEqual(len(opts), 3)
self.assertEqual([op.type for op in opts],
["fill_constant", "elementwise_mul", "ftrl"])
......
......@@ -15,6 +15,7 @@
import unittest
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
import paddle.fluid.profiler as profiler
import numpy
......@@ -115,7 +116,7 @@ class BaseParallelForTest(unittest.TestCase):
if use_parallel:
thread_num = fluid.core.get_cuda_device_count(
) if use_gpu else 8
places = fluid.layers.get_places(thread_num)
places = get_places(thread_num)
pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
data = next(generator)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
# Correct: General.
class TestSqueezeOp(OpTest):
def setUp(self):
self.op_type = "squeeze"
self.init_test_case()
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, 2)
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": False}
# Correct: There is mins axis.
class TestSqueezeOp1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, -2)
self.new_shape = (3, 5)
# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = ()
self.new_shape = (3, 5)
# Correct: Just part of axes be squeezed.
class TestSqueezeOp3(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1)
self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4)
# Correct: Inplace.
class TestSqueezeOpInplace1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, 2)
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is mins axis.
class TestSqueezeOpInplace2(TestSqueezeOp):
def inti_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, -2)
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. No axes input.
class TestSqueezeOpInplace3(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = ()
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inpalce. Just part of axes be squeezed.
class TestSqueezeOpInplace4(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1)
self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
# Correct: General.
class TestUnsqueezeOp(OpTest):
def setUp(self):
self.init_test_case()
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (1, 2)
self.new_shape = (3, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": False}
# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (-1, )
self.new_shape = (3, 5, 1)
# Correct: Mixed input axis.
class TestUnsqueezeOp2(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (0, -1)
self.new_shape = (1, 3, 5, 1)
# Correct: There is duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 2, 5)
self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5)
# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 2, 5)
self.axes = (3, 1, 1)
self.new_shape = (3, 1, 1, 2, 5, 1)
# Correct: Inplace.
class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (0, 2)
self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is mins index.
class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (0, -2)
self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is duplicated axis.
class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 2, 5)
self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
if __name__ == "__main__":
unittest.main()
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册