Commit 3cf50b74 authored by seiriosPlus

fix fuse

Parent 66321576
@@ -82,7 +82,7 @@ class LargeScaleFuseAdamOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
     AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
+    AddInput("LearningRate", "(Tensor) Learning rate of SGD");
     AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
     AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
@@ -150,6 +150,4 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(
     lookup_sparse_table_fuse_adam,
-    ops::LargeScaleFuseAdamOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::LargeScaleFuseAdamOpKernel<paddle::platform::CPUDeviceContext,
-                                    double>);
+    ops::LargeScaleFuseAdamOpKernel<paddle::platform::CPUDeviceContext, float>);
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/distributed/large_scale_kv.h"
 #include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"

 namespace paddle {
 namespace operators {
@@ -37,8 +38,9 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
     : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+    using paddle::framework::LoDTensor;
+
+    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
     const auto *grad_var = ctx.InputVar("Grad");

     PADDLE_ENFORCE(
@@ -56,8 +58,8 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
     framework::SelectedRows tmp_grad_merge;
     const framework::SelectedRows *grad_merge_ptr;
-    math::scatter::MergeAdd<DeviceContext, T> merge_func;
-    merge_func(context.template device_context<DeviceContext>(), *in_grad,
+    math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
+    merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
                &tmp_grad_merge, true);
     grad_merge_ptr = &tmp_grad_merge;
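The rewritten lines above instantiate MergeAdd with the concrete CPUDeviceContext and feed it the op's Grad input. For reference, a minimal standalone sketch of what math::scatter::MergeAdd computes, using hypothetical types rather than Paddle's SelectedRows API: rows that share an id are summed into a single row, so the update loop later touches each id exactly once.

#include <cstdint>
#include <map>
#include <vector>

// Hypothetical standalone sketch of MergeAdd's effect on a sparse gradient.
struct SparseGrad {
  std::vector<int64_t> rows;   // row ids, possibly repeated
  std::vector<float> values;   // rows.size() x width, row-major
  int64_t width = 0;
};

SparseGrad MergeAddSketch(const SparseGrad &in) {
  std::map<int64_t, std::vector<float>> acc;
  for (size_t i = 0; i < in.rows.size(); ++i) {
    auto &row = acc[in.rows[i]];
    row.resize(in.width, 0.f);
    for (int64_t x = 0; x < in.width; ++x)
      row[x] += in.values[i * in.width + x];  // sum duplicate rows
  }
  SparseGrad out;
  out.width = in.width;
  for (auto &kv : acc) {                      // one unique row per id
    out.rows.push_back(kv.first);
    out.values.insert(out.values.end(), kv.second.begin(), kv.second.end());
  }
  return out;
}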
@@ -71,8 +73,8 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
     auto grad_width = grad_v.dims()[1];

     // auto is_entry = context.Attr<bool>("is_entry");
-    auto tablename = context.Attr<std::string>("tablename");
-    auto value_names = Attr<std::vector<std::string>>("value_names");
+    auto tablename = ctx.Attr<std::string>("tablename");
+    auto value_names = ctx.Attr<std::vector<std::string>>("value_names");
     auto *beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
     auto *beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
@@ -116,11 +118,11 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
       auto &moment_1 = values[1];
       auto &moment_2 = values[2];

-      T lr = *lr_;
+      T lr_ = lr[0];
       T beta1_ = beta1_pow->data<T>()[0];
       T beta2_ = beta2_pow->data<T>()[0];

-      lr *= sqrt(1 - beta1_) / (1 - beta2_);
+      lr_ *= sqrt(1 - beta1_) / (1 - beta2_);

       for (size_t i = 0; i < in_rows.size(); i++) {
         auto *m1_data = moment_1[i]->data();
@@ -131,7 +133,7 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
           auto g = grad_v.data<T>()[grad_width * i + x];
           m1_data[x] = beta1_ * m1_data[x] + (1 - beta1_) * g;
           m2_data[x] = beta2_ * m2_data[x] + (1 - beta2_) * g * g;
-          p_data[x] -= lr * (m1_data[x] / (sqrt(m2_data[x]) + epsilon));
+          p_data[x] -= lr_ * (m1_data[x] / (sqrt(m2_data[x]) + epsilon));
         }
       }
     }
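The hunks above only rename the step size (lr_, now read from the new LearningRate input) and keep the pre-scaling by sqrt(1 - beta1_) / (1 - beta2_); the per-element arithmetic is unchanged. For reference, a minimal standalone sketch of one inner-loop iteration, with hypothetical names (beta1_/beta2_ are the scalars the kernel reads from the Beta1Pow/Beta2Pow tensors):

#include <cmath>

// Hypothetical standalone form of the fused-Adam inner-loop body above;
// lr_ is assumed to be pre-scaled exactly as in the kernel.
inline void AdamElementUpdate(float &p, float &m1, float &m2, float g,
                              float lr_, float beta1_, float beta2_,
                              float epsilon) {
  m1 = beta1_ * m1 + (1 - beta1_) * g;          // first-moment average
  m2 = beta2_ * m2 + (1 - beta2_) * g * g;      // second-moment average
  p -= lr_ * (m1 / (std::sqrt(m2) + epsilon));  // parameter step
}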
...
@@ -79,7 +79,7 @@ class LargeScaleFuseSGDOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Grad",
              "(SelectedRows) Ids's type should be SelectedRows"
              "THe ids to be looked up in W.");
+    AddInput("LearningRate", "(Tensor) Learning rate of SGD");
     AddAttr<bool>("is_entry",
                   "(bool)"
                   "sparse table need entry");
@@ -117,5 +117,4 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(
     lookup_sparse_table_fuse_sgd,
-    ops::LargeScaleFuseSGDOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::LargeScaleFuseSGDOpKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::LargeScaleFuseSGDOpKernel<paddle::platform::CPUDeviceContext, float>);
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/distributed/large_scale_kv.h"
 #include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"

 namespace paddle {
 namespace operators {
@@ -56,8 +57,8 @@ class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
     framework::SelectedRows tmp_grad_merge;
     const framework::SelectedRows *grad_merge_ptr;
-    math::scatter::MergeAdd<DeviceContext, T> merge_func;
-    merge_func(context.template device_context<DeviceContext>(), *in_grad,
+    math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
+    merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
                &tmp_grad_merge, true);
     grad_merge_ptr = &tmp_grad_merge;
@@ -71,8 +72,8 @@ class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
     auto grad_width = grad_v.dims()[1];

     // auto is_entry = context.Attr<bool>("is_entry");
-    auto tablename = context.Attr<std::string>("tablename");
-    auto value_names = Attr<std::vector<std::string>>("value_names");
+    auto tablename = ctx.Attr<std::string>("tablename");
+    auto value_names = ctx.Attr<std::vector<std::string>>("value_names");

     std::vector<std::vector<std::vector<float> *>> values;
     std::vector<int64_t> dims;
@@ -88,15 +89,16 @@ class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
     auto &params = values[0];

-    auto blas = math::GetBlas<DeviceContext, T>(context);
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);

     std::vector<T> grads;
-    framework::TensorToVector(grad_v, context.device_context(), &grads);
+    framework::TensorToVector(grad_v, ctx.device_context(), &grads);

-    blas.VMUL(grads, lr[0], grads);
+    blas.SCAL(grads.size(), lr[0], grads.data());

     for (int x = 0; x < static_cast<int>(in_rows.size()); ++x) {
-      blas.VSUB(grad_width, params[x], grads.data() + grad_width * x, params);
+      blas.VSUB(grad_width, params[x]->data(), grads.data() + grad_width * x,
+                params[x]->data());
     }
   }
 };
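The last hunk swaps the ill-typed VMUL call for SCAL, which scales the merged gradient by the learning rate in place, and passes raw float pointers to VSUB so each parameter row is updated as w_row ← w_row − lr · g_row. A minimal standalone sketch of the equivalent computation, with hypothetical names:

#include <cstddef>
#include <vector>

// Hypothetical standalone form of the SCAL + per-row VSUB sequence above.
void SgdRowsUpdate(std::vector<float *> &params,  // one pointer per param row
                   std::vector<float> &grads,     // rows x width, row-major
                   std::size_t width, float lr) {
  for (float &g : grads) g *= lr;                 // blas.SCAL(n, lr, grads.data())
  for (std::size_t r = 0; r < params.size(); ++r) {
    const float *g = grads.data() + width * r;    // gradient slice for row r
    for (std::size_t x = 0; x < width; ++x)
      params[r][x] -= g[x];                       // blas.VSUB(width, p, g, p)
  }
}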
...
@@ -657,13 +657,15 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
         if op.type == "sgd":
             grad = main_program.global_block().vars[op.input("Grad")[0]]
+            lr = main_program.global_block().vars[op.input("LearningRate")[0]]

             # remove origin optimzier op
             block._remove_op(opt_idx)
             block._insert_op(
                 opt_idx,
                 type="lookup_sparse_table_fuse_sgd",
-                inputs={"Grad": grad},
+                inputs={"Grad": grad,
+                        "LearningRate": lr},
                 attrs={
                     "is_entry": is_entry,
                     "tablename": table_name,
@@ -672,6 +674,7 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
         elif op.type == "adam":
             grad = main_program.global_block().vars[op.input("Grad")[0]]
+            lr = main_program.global_block().vars[op.input("LearningRate")[0]]
             beta1_pow = main_program.global_block().vars[op.input("Beta1Pow")[
                 0]]
             beta2_pow = main_program.global_block().vars[op.input("Beta2Pow")[
@@ -693,6 +696,7 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
                 type="lookup_sparse_table_fuse_adam",
                 inputs={
                     "Grad": grad,
+                    "LearningRate": lr,
                     "Beta1Pow": beta1_pow,
                     "Beta2Pow": beta2_pow
                 },
...