25498a4263de51af3b49a29c75a6c74e
从零实现机器学习参数服务器框架(五)

简介

序号 名称 简介
1 框架架构 介绍整体框架架构、涉及模块
2 模型、数据计算并行 实现参数服务器中的Worker与Server
3 节点间通信接口 实现应用常用接口,使用ZeroMQ实现节点间通信
4 一致性控制 在Server端实现常见的一致性控制模式:BSP/ASP/SSP
5 容错实现 使用CheckPoint Rollback、心跳监测等方式实现计算容错
6 Yarn资源管理系统支持 编写Yarn Application以使用Yarn进行资源分配
7 机器学习算法实现 实现Logistic Regression与KMeans两种算法

截至上一篇为止(1-4),参数服务器框架已经有了基础的功能,我们已经可以使用这个框架来实现机器学习算法了, 在这个框架中我们可以并行的读取数据并保存在内存中、分布式的缓存Matrix与Vector参数、可供选择三种不同的一致性模式(BSP/ASP/SSP)来递归的计算梯度、利用ZeroMQ来实现节点间网络的通信。

然而对于一个可靠的框架,当运行在不同节点的进程出现故障时,集群应该具有恢复这个故障节点的能力。这种功能也称为容错。在这篇文章中,我将使用CheckPoint/RollBack、心跳监测机制来自动监测节点故障并恢复节点的训练进度。

容错实现

容错的实现我主要分成了7个不同的Stage:

  1. 各个节点在一定的间隔进行CheckPoint,备份数据
  2. 所有节点都完成CheckPoint
  3. 单个节点进程出现故障退出
  4. Master节点通过心跳机制监测到了故障
  5. 故障节点重启进程
  6. 故障节点重启成功并完成数据加载
  7. 所有节点恢复正常,基于之前CheckPoint备份的数据继续训练

以上就是涉及的具体方案,我对各个阶段耗时进行了测试:

Phase From To 耗时 耗时占比
1 启动CheckPoint 结束CheckPoint 18,771ms 9%
2 单个节点出现故障 Master监测到故障 50,023ms 24%
3 Master监测到故障 故障进程开始重启 10ms 1%
4 故障进程开始重启 故障进程重启成功 122,318ms 65%
5 故障进程重启成功 所有节点恢复正常,集群继续训练 30ms 1%

fault_tolerance

fault_tolerance

可以看到最耗时的阶段在故障进程开始重启故障进程重启成功,总共耗时占比65%。

Master

master

master

实现容错主要采用了主从结构,这样可以采用心跳监测机制来判断集群内是否有节点出现了故障。

master/master.hpp

class MasterThread : public Actor {
    public:
    MasterThread(uint32_t master_id, const std::vector<Node> &nodes) : Actor(master_id), nodes_(nodes) {}

    void Init() {
        serving_ = true;
        quit_count_ = 0;
        for (Node node : nodes_) {
            heartbeats_[node.id] = time(NULL);
        }
    }

    time_t GetHeartBeat(uint32_t node_id) {
        return heartbeats_[node_id];
    }

    void RollBack(int32_t failed_node_id) {
        rollback_func_(failed_node_id);
    }

    void SetRollBack(std::function<void(int32_t)> func) {
        rollback_func_ = func;
    }

    void SetRecoveringNodeId(int32_t recovering_node_id) {
        recovering_node_id_ = recovering_node_id;
    }

    bool IsRecovering() {
        return recovering_node_id_ != -1;
    }

    protected:
    virtual void Main() override;

    // 使用Map缓存每个节点的心跳更新状态
    std::unordered_map<uint32_t, time_t> heartbeats_;
};

master/master_thread.cpp

```cpp
void MasterThread::Main() {
Init();

// 无限循环,不断接收新的消息
while (serving_) {
    Message msg;
    work_queue_.WaitAndPop(&msg);

    if (msg.meta.flag == Flag::kHeartBeat) {
        // 更新节点的心跳信息
        heartbeats_[msg.meta.sender] = time(NULL);
top Created with Sketch.