Commit 641b4c0f authored by typhoonzero

wip

Parent 74b12288
@@ -105,48 +105,18 @@ struct SparseAdagradFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor& learning_rate, T epsilon,
                   framework::Tensor* moment, framework::Tensor* param) {
     // 1. g_m.rows = set(g.rows)
-    auto grad_rows = grad.rows();
-    std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
-    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
     auto grad_width = grad.value().dims()[1];
-    std::unique_ptr<framework::SelectedRows> grad_merge{
-        new framework::SelectedRows()};
-    grad_merge->set_rows(merge_rows);
-    grad_merge->set_height(grad.height());
-    grad_merge->mutable_value()->mutable_data<T>(
-        framework::make_ddim(
-            {static_cast<int64_t>(merge_rows.size()), grad_width}),
-        context.GetPlace());
-    math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
-    constant_functor(context, grad_merge->mutable_value(), 0.0);
-    auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
-    auto* grad_data = grad.value().data<T>();
-    for (size_t i = 0; i < grad_rows.size(); i++) {
-      size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]);
-      for (int64_t j = 0; j < grad_width; j++) {
-        grad_merge_data[grad_merge_i * grad_width + j] +=
-            grad_data[i * grad_width + j];
-      }
-    }
+    math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
+    auto grad_merge = merge_func(context, grad);
+    auto& merge_rows = grad_merge.rows();
+    auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
     // 2. m += g_m * g_m
-    std::unique_ptr<framework::SelectedRows> grad_square{
-        new framework::SelectedRows()};
-    grad_square->set_rows(grad_merge->rows());
-    grad_square->set_height(grad_merge->height());
-    grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
-                                                  context.GetPlace());
-    auto gs =
-        framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
-    auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
-    gs.device(*context.eigen_device()) = gm * gm;
+    math::scatter::Mul<platform::CPUDeviceContext, T> square_func;
+    auto grad_square = square_func(context, grad_merge, grad_merge);
     math::SelectedRowsAddToTensor<platform::CPUDeviceContext, T> functor;
-    functor(context, *grad_square, moment);
+    functor(context, grad_square, moment);
     // 3. update parameter
     auto* lr = learning_rate.data<T>();
......
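The hunk above replaces the hand-rolled duplicate-row merge loop with math::scatter::MergeAdd followed by math::scatter::Mul. As a self-contained illustration of what the merge step does — plain standard containers and made-up row indices and values, not Paddle's SelectedRows — duplicate row ids are summed into a single row:

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Sum rows that share the same row index, mirroring the "g_m.rows = set(g.rows)"
// step above. A row-index vector plus a dense [rows x width] buffer stands in
// for a SelectedRows; this is an illustration only.
int main() {
  const int64_t width = 2;
  std::vector<int64_t> rows = {0, 2, 0};           // row 0 appears twice
  std::vector<double> value = {1, 1, 2, 2, 3, 3};  // 3 x width
  std::map<int64_t, std::vector<double>> merged;
  for (size_t i = 0; i < rows.size(); ++i) {
    auto& acc = merged[rows[i]];
    acc.resize(width, 0.0);
    for (int64_t j = 0; j < width; ++j) acc[j] += value[i * width + j];
  }
  for (const auto& kv : merged) {                  // merged rows: {0, 2}
    std::cout << "row " << kv.first << ":";
    for (double v : kv.second) std::cout << " " << v;  // row 0 -> 4 4
    std::cout << "\n";
  }
  return 0;
}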
@@ -78,51 +78,17 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
                   const framework::Tensor& learning_rate, T epsilon,
                   framework::Tensor* moment, framework::Tensor* param) {
     // 1. g_m.rows = set(g.rows)
-    auto grad_rows = grad.rows();
-    std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
-    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
     auto grad_width = grad.value().dims()[1];
-    std::unique_ptr<framework::SelectedRows> grad_merge{
-        new framework::SelectedRows()};
-    grad_merge->set_rows(merge_rows);
-    grad_merge->set_height(grad.height());
-    grad_merge->mutable_value()->mutable_data<T>(
-        framework::make_ddim(
-            {static_cast<int64_t>(merge_rows.size()), grad_width}),
-        context.GetPlace());
-    math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
-    constant_functor(context, grad_merge->mutable_value(), 0.0);
-    auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
-    auto* grad_data = grad.value().data<T>();
-    const int block_size = 256;
-    dim3 threads(block_size, 1);
-    dim3 grid1(1, grad_rows.size());
-    MergeGradKernel<
-        T, 256><<<grid1, threads, 0,
-                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(grad_data, grad.rows().data(),
-                                   grad_merge_data, grad_merge->rows().data(),
-                                   grad_merge->rows().size(), grad_width);
+    math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
+    auto grad_merge = merge_func(context, grad);
+    auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
+    auto& merge_rows = grad_merge.rows();
     // 2. m += g_m * g_m
-    std::unique_ptr<framework::SelectedRows> grad_square{
-        new framework::SelectedRows()};
-    grad_square->set_rows(grad_merge->rows());
-    grad_square->set_height(grad_merge->height());
-    grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
-                                                  context.GetPlace());
-    auto gs =
-        framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
-    auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
-    gs.device(*context.eigen_device()) = gm * gm;
+    math::scatter::Mul<platform::CPUDeviceContext, T> square_func;
+    auto grad_square = square_func(context, grad_merge, grad_merge);
     math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor;
-    functor(context, *grad_square, moment);
+    functor(context, grad_square, moment);
     // 3. update parameter
     auto* lr = learning_rate.data<T>();
......
@@ -16,11 +16,14 @@ limitations under the License. */
 #include <math.h>  // for sqrt in CPU and CUDA
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/detail/safe_ref.h"
+#include "paddle/operators/math/selected_rows_functor.h"
 #include "paddle/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
 
+namespace scatter = paddle::operators::math::scatter;
+
 template <typename T>
 struct AdamFunctor {
   T beta1_;
@@ -134,8 +137,6 @@ struct SparseAdamFunctor {
       mom1 = beta1_ * mom1 + (1 - beta1_) * g;
       mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
       p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
-      // IMPORTANT:
-      // FIXME(typhoonzero): row id may be duplicate
       moment1_out_[rows_[i] * row_numel_ + j] = mom1;
       moment2_out_[rows_[i] * row_numel_ + j] = mom2;
       param_out_[rows_[i] * row_numel_ + j] = p;
@@ -191,10 +192,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       auto& grad =
           Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
-      auto& grad_tensor = grad.value();
+      // merge duplicated rows if any.
+      scatter::MergeAdd<DeviceContext, T> merge_func;
+      auto grad_merge =
+          merge_func(ctx.template device_context<DeviceContext>(), grad);
+      auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      auto* rows = grad.rows().data();
-      auto row_numel = grad_tensor.numel() / grad.rows().size();
+      auto* rows = grad_merge.rows().data();
+      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
 
       SparseAdamFunctor<T> functor(
           beta1, beta2, epsilon, beta1_pow.template data<T>(),
@@ -206,7 +211,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
           param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
       platform::ForRange<DeviceContext> for_range(
           static_cast<const DeviceContext&>(ctx.device_context()),
-          grad.rows().size());
+          grad_merge.rows().size());
       for_range(functor);
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");
......
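After the MergeAdd call above, SparseAdamFunctor addresses moment and parameter storage with the flat offset rows_[i] * row_numel_ + j, so each merged row is written exactly once — which is why the earlier FIXME about duplicate row ids can be dropped. A minimal, framework-free sketch of that index arithmetic, with hypothetical row ids and row width:

#include <cstdint>
#include <cstdio>
#include <vector>

// Map (sparse row i, column j) to the flat offset in the dense parameter,
// the same arithmetic SparseAdamFunctor uses on the merged gradient.
int main() {
  std::vector<int64_t> rows = {3, 7};  // merged, duplicate-free row ids
  const int64_t row_numel = 4;         // elements per row
  for (size_t i = 0; i < rows.size(); ++i) {
    for (int64_t j = 0; j < row_numel; ++j) {
      int64_t offset = rows[i] * row_numel + j;  // e.g. i=1, j=1 -> 29
      std::printf("grad[%zu][%lld] updates param[%lld]\n", i,
                  static_cast<long long>(j), static_cast<long long>(offset));
    }
  }
  return 0;
}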
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/math/selected_rows_functor.h"
+#include <set>
+
 #include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/selected_rows_functor.h"
 
 namespace paddle {
 namespace operators {
@@ -193,27 +195,25 @@ size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
 
 template <typename T>
 struct MergeAdd<platform::CPUDeviceContext, T> {
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* out) {
+  framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
+                                     const framework::SelectedRows& input) {
+    framework::SelectedRows out;
     auto input_rows = input.rows();
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
 
     auto input_width = input.value().dims()[1];
-    // std::unique_ptr<framework::SelectedRows> out{
-    //     new framework::SelectedRows()};
-    out->set_rows(merge_rows);
-    out->set_height(input.height());
-    out->mutable_value()->mutable_data<T>(
+    out.set_rows(merge_rows);
+    out.set_height(input.height());
+    out.mutable_value()->mutable_data<T>(
         framework::make_ddim(
             {static_cast<int64_t>(merge_rows.size()), input_width}),
         context.GetPlace());
 
     math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
-    constant_functor(context, out->mutable_value(), 0.0);
+    constant_functor(context, out.mutable_value(), 0.0);
 
-    auto* out_data = out->mutable_value()->data<T>();
+    auto* out_data = out.mutable_value()->data<T>();
     auto* input_data = input.value().data<T>();
 
     for (size_t i = 0; i < input_rows.size(); i++) {
@@ -222,6 +222,74 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
         out_data[out_i * input_width + j] += input_data[i * input_width + j];
       }
     }
+    return out;
+  }
+};
+
+template struct MergeAdd<platform::CPUDeviceContext, float>;
+template struct MergeAdd<platform::CPUDeviceContext, double>;
+template struct MergeAdd<platform::CPUDeviceContext, int>;
+template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
+
+template <typename T>
+struct UpdateToTensor<platform::CPUDeviceContext, T> {
+  framework::Tensor operator()(const platform::CPUDeviceContext& context,
+                               const ScatterOps& op,
+                               const framework::SelectedRows& input1,
+                               framework::Tensor* input2) {
+    auto in1_height = input1.height();
+    auto in2_dims = input2->dims();
+    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
+
+    auto& in1_value = input1.value();
+    auto& in1_rows = input1.rows();
+
+    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
+    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
+
+    auto* in1_data = in1_value.data<T>();
+    auto* input2_data = input2->data<T>();
+
+    // FIXME(typhoonzero): use macro fix the below messy code.
+    switch (op) {
+      case ScatterOps::ASSIGN:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] =
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::ADD:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] +=
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::SUB:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] -=
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::SUBBY:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] =
+            in1_data[i * in1_row_numel + j] -
+            input2_data[in1_rows[i] * in1_row_numel + j];
+        break;
+      case ScatterOps::MUL:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] *=
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::DIV:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] /=
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::DIVBY:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] =
+            in1_data[i * in1_row_numel + j] /
+            input2_data[in1_rows[i] * in1_row_numel + j];
+        break;
+    }
   }
 };
......
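The new CPU UpdateToTensor dispatches on ScatterOps and applies the chosen element-wise update at each sparse row of a dense tensor. The sketch below restates those semantics on plain std::vector buffers — an illustration of the scatter-update rule only, not the Paddle implementation; the enum is redefined locally so the snippet stands alone:

#include <cstdint>
#include <vector>

// Illustrative scatter update: for each selected row, apply
// "tensor[row] op= update_row", matching the switch above.
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };

void UpdateToTensor(ScatterOps op, const std::vector<int64_t>& rows,
                    const std::vector<double>& update, int64_t row_numel,
                    std::vector<double>* tensor) {
  for (size_t i = 0; i < rows.size(); ++i) {
    for (int64_t j = 0; j < row_numel; ++j) {
      double& t = (*tensor)[rows[i] * row_numel + j];
      const double u = update[i * row_numel + j];
      switch (op) {
        case ScatterOps::ASSIGN: t = u; break;
        case ScatterOps::ADD:    t += u; break;
        case ScatterOps::SUB:    t -= u; break;
        case ScatterOps::SUBBY:  t = u - t; break;  // "subtract by" reverses operands
        case ScatterOps::MUL:    t *= u; break;
        case ScatterOps::DIV:    t /= u; break;
        case ScatterOps::DIVBY:  t = u / t; break;  // "divide by" reverses operands
      }
    }
  }
}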
@@ -252,27 +252,26 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
 
 template <typename T>
 struct MergeAdd<platform::GPUDeviceContext, T> {
-  void operator()(const platform::GPUDeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* out) {
+  framework::SelectedRows operator()(const platform::GPUDeviceContext& context,
+                                     const framework::SelectedRows& input) {
+    framework::SelectedRows out;
     auto input_rows = input.rows();
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
 
     auto input_width = input.value().dims()[1];
-    // std::unique_ptr<framework::SelectedRows> out{
-    //     new framework::SelectedRows()};
-    out->set_rows(merge_rows);
-    out->set_height(input.height());
-    out->mutable_value()->mutable_data<T>(
+
+    out.set_rows(merge_rows);
+    out.set_height(input.height());
+    out.mutable_value()->mutable_data<T>(
         framework::make_ddim(
             {static_cast<int64_t>(merge_rows.size()), input_width}),
         context.GetPlace());
 
     math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
-    constant_functor(context, out->mutable_value(), 0.0);
+    constant_functor(context, out.mutable_value(), 0.0);
 
-    auto* out_data = out->mutable_value()->data<T>();
+    auto* out_data = out.mutable_value()->data<T>();
     auto* input_data = input.value().data<T>();
 
     const int block_size = 256;
@@ -283,11 +282,96 @@ struct MergeAdd<platform::GPUDeviceContext, T> {
         T, 256><<<grid1, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
                       .stream()>>>(input_data, input.rows().data(), out_data,
-                                   out->rows().data(), out->rows().size(),
+                                   out.rows().data(), out.rows().size(),
                                    input_width);
+    return out;
   }
 };
+
+template struct MergeAdd<platform::GPUDeviceContext, float>;
+template struct MergeAdd<platform::GPUDeviceContext, double>;
+template struct MergeAdd<platform::GPUDeviceContext, int>;
+template struct MergeAdd<platform::GPUDeviceContext, int64_t>;
+
+template <typename T, int block_size>
+__global__ void UpdateToTensorKernel(const T* selected_rows,
+                                     const int64_t* rows, const ScatterOps& op,
+                                     T* tensor_out, int64_t row_numel) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+
+  selected_rows += ty * row_numel;
+  tensor_out += rows[ty] * row_numel;
+  // FIXME(typhoonzero): use macro fix the below messy code.
+  switch (op) {
+    case ScatterOps::ASSIGN:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] = selected_rows[index];
+      }
+      break;
+    case ScatterOps::ADD:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] += selected_rows[index];
+      }
+      break;
+    case ScatterOps::SUB:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] -= selected_rows[index];
+      }
+      break;
+    case ScatterOps::SUBBY:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] = selected_rows[index] - tensor_out[index];
+      }
+      break;
+    case ScatterOps::MUL:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] *= selected_rows[index];
+      }
+      break;
+    case ScatterOps::DIV:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] /= selected_rows[index];
+      }
+      break;
+    case ScatterOps::DIVBY:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] = selected_rows[index] / tensor_out[index];
+      }
+      break;
+  }
+}
+
+template <typename T>
+struct UpdateToTensor<platform::GPUDeviceContext, T> {
+  framework::Tensor operator()(const platform::GPUDeviceContext& context,
+                               const ScatterOps& op,
+                               const framework::SelectedRows& input1,
+                               framework::Tensor* input2) {
+    // NOTE: Use SelectedRowsAddToTensor for better performance
+    // no additional MergeAdd called.
+    auto merged_in1 =
+        MergeAdd<platform::GPUDeviceContext, T>()(context, input1);
+
+    auto in1_height = merged_in1.height();
+    auto in2_dims = input2->dims();
+    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
+
+    auto& in1_value = merged_in1.value();
+    auto& in1_rows = merged_in1.rows();
+
+    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
+    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
+
+    auto* in1_data = in1_value.data<T>();
+    auto* input2_data = input2->data<T>();
+
+    dim3 threads(PADDLE_CUDA_NUM_THREADS, 1);
+    dim3 grid(1, in1_rows.size());
+    UpdateToTensorKernel<
+        T, PADDLE_CUDA_NUM_THREADS><<<grid, threads, 0, context.stream()>>>(
+        in1_data, in1_rows.data(), op, input2_data, in1_row_numel);
+  }
+};
 
 }  // namespace scatter
 }  // namespace math
 }  // namespace operators
......
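In UpdateToTensorKernel each CUDA block (blockIdx.y) owns one selected row, and its threads stride across that row's elements in steps of the block size, matching the (1, in1_rows.size()) grid launched above. Below is a host-side C++ rendering of the same index pattern for the ADD case; the function and variable names are made up for the illustration:

#include <cstdint>
#include <vector>

// Host-side emulation of the kernel's indexing: "block" ty handles rows[ty],
// "thread" tid touches elements tid, tid + block_size, ... of that row.
void EmulateUpdateKernelAdd(const std::vector<double>& selected_rows,
                            const std::vector<int64_t>& rows,
                            int64_t row_numel, int block_size,
                            std::vector<double>* tensor_out) {
  for (size_t ty = 0; ty < rows.size(); ++ty) {
    const double* src = selected_rows.data() + ty * row_numel;
    double* dst = tensor_out->data() + rows[ty] * row_numel;
    for (int tid = 0; tid < block_size; ++tid) {
      for (int64_t index = tid; index < row_numel; index += block_size) {
        dst[index] += src[index];  // ScatterOps::ADD
      }
    }
  }
}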
@@ -16,6 +16,10 @@ limitations under the License. */
 #include "paddle/framework/selected_rows.h"
 #include "paddle/platform/device_context.h"
 
+#define INLINE_FOR2(sizei, sizej)     \
+  for (int64_t i = 0; i < sizei; i++) \
+    for (int64_t j = 0; j < sizej; j++)
+
 namespace paddle {
 namespace operators {
 namespace math {
@@ -55,50 +59,76 @@ struct SelectedRowsAddToTensor {
 
 namespace scatter {
 // functors for manuplating SelectedRows data
 template <typename DeviceContext, typename T>
 struct MergeAdd {
   // unary functor, merge by adding duplicated rows in
   // the input SelectedRows object.
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* out);
+  framework::SelectedRows operator()(const DeviceContext& context,
+                                     const framework::SelectedRows& input);
 };
 
 template <typename DeviceContext, typename T>
 struct Add {
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  const framework::SelectedRows& input2,
-                  framework::SelectedRows* out) {
-    out->set_rows(input1.rows());
-    out->set_height(input1.height());
-    out->mutable_value()->mutable_data<T>(input1.value().dims(),
-                                          context.GetPlace());
-    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
+  framework::SelectedRows operator()(const DeviceContext& context,
+                                     const framework::SelectedRows& input1,
+                                     const framework::SelectedRows& input2) {
+    framework::SelectedRows out;
+    out.set_rows(input1.rows());
+    out.set_height(input1.height());
+    out.mutable_value()->mutable_data<T>(input1.value().dims(),
+                                         context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
     auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
     auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
     e_out.device(*context.eigen_device()) = e_in1 + e_in2;
+    return out;
   }
 };
 
 template <typename DeviceContext, typename T>
 struct Mul {
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  const framework::SelectedRows& input2,
-                  framework::SelectedRows* out) {
-    out->set_rows(input1.rows());
-    out->set_height(input1.height());
-    out->mutable_value()->mutable_data<T>(input1.value().dims(),
-                                          context.GetPlace());
-    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
+  // multiply two SelectedRows
+  framework::SelectedRows operator()(const DeviceContext& context,
+                                     const framework::SelectedRows& input1,
+                                     const framework::SelectedRows& input2) {
+    framework::SelectedRows out;
+    out.set_rows(input1.rows());
+    out.set_height(input1.height());
+    out.mutable_value()->mutable_data<T>(input1.value().dims(),
+                                         context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
     auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
     auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
     e_out.device(*context.eigen_device()) = e_in1 * e_in2;
+    return out;
+  }
+  // multiply scalar to SelectedRows
+  framework::SelectedRows operator()(const DeviceContext& context,
+                                     const framework::SelectedRows& input1,
+                                     const T input2) {
+    framework::SelectedRows out;
+    out.set_rows(input1.rows());
+    out.set_height(input1.height());
+    out.mutable_value()->mutable_data<T>(input1.value().dims(),
+                                         context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
+    auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
+    e_out.device(*context.eigen_device()) = input2 * e_in1;
+    return out;
   }
 };
 
+enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
+
+// out = selected_rows_in / tensor
+template <typename DeviceContext, typename T>
+struct UpdateToTensor {
+  framework::Tensor operator()(const DeviceContext& context,
+                               const ScatterOps& op,
+                               const framework::SelectedRows& input1,
+                               framework::Tensor* input2);
+};
+
 }  // namespace scatter
 }  // namespace math
 }  // namespace operators
......
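The new INLINE_FOR2 macro expands to two nested loops whose body is whatever single statement follows it, with the loop variables i and j left visible to that statement — this is what the UpdateToTensor CPU switch above relies on. A self-contained example of the expansion:

#include <cstdint>
#include <cstdio>

// Same macro as in selected_rows_functor.h: the statement after the macro
// becomes the body of the inner loop.
#define INLINE_FOR2(sizei, sizej)     \
  for (int64_t i = 0; i < sizei; i++) \
    for (int64_t j = 0; j < sizej; j++)

int main() {
  INLINE_FOR2(2, 3)
  std::printf("i=%lld j=%lld\n", static_cast<long long>(i),
              static_cast<long long>(j));  // runs 2 * 3 = 6 times
  return 0;
}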
@@ -285,7 +285,6 @@ class TestSparseAdamOp(unittest.TestCase):
             j = 0
             while j < self.row_numel:
                 pos = row_id * self.row_numel + j
-                print(actual[pos] - np_array[pos]) / actual[pos]
                 self.assertLess((actual[pos] - np_array[pos]) / actual[pos],
                                 0.00001)
                 j += 1
......