简介

前文我们介绍了线程池,已经给大家提供了一个完整的线程池封装了,本节跟着《C++ 并发编程实战》一书中作者的思路,看看他的线程池的实现,以此作为补充

轮询方式的线程池

配合我们之前封装的线程安全队列threadsafe_queue

#include <mutex>
#include <queue>

template<typename T>
class threadsafe_queue
{
private:
    struct node
    {
        std::shared_ptr<T> data;
        std::unique_ptr<node> next;
        node* prev;
    };

    std::mutex head_mutex;
    std::unique_ptr<node> head;
    std::mutex tail_mutex;
    node* tail;
    std::condition_variable data_cond;
    std::atomic_bool  bstop;

    node* get_tail()
    {
        std::lock_guard<std::mutex> tail_lock(tail_mutex);
        return tail;
    }
    std::unique_ptr<node> pop_head()   
    {
        std::unique_ptr<node> old_head = std::move(head);
        head = std::move(old_head->next);
        return old_head;
    }

    std::unique_lock<std::mutex> wait_for_data()   
    {
        std::unique_lock<std::mutex> head_lock(head_mutex);
        data_cond.wait(head_lock,[&] {return head.get() != get_tail() || bstop.load() == true; });
        return std::move(head_lock);   
    }

        std::unique_ptr<node> wait_pop_head()
        {
            std::unique_lock<std::mutex> head_lock(wait_for_data());   
            if (bstop.load()) {
                return nullptr;
            }

                return pop_head();
        }
        std::unique_ptr<node> wait_pop_head(T& value)
        {
            std::unique_lock<std::mutex> head_lock(wait_for_data());  
            if (bstop.load()) {
                return nullptr;
            }
            value = std::move(*head->data);
            return pop_head();
        }

        std::unique_ptr<node> try_pop_head()
        {
            std::lock_guard<std::mutex> head_lock(head_mutex);
            if (head.get() == get_tail())
            {
                return std::unique_ptr<node>();
            }
            return pop_head();
        }
        std::unique_ptr<node> try_pop_head(T& value)
        {
            std::lock_guard<std::mutex> head_lock(head_mutex);
            if (head.get() == get_tail())
            {
                return std::unique_ptr<node>();
            }
            value = std::move(*head->data);
            return pop_head();
        }
public:

    threadsafe_queue() :  // ⇽-- - 1
        head(new node), tail(head.get())
    {}

    ~threadsafe_queue() {
        bstop.store(true);
        data_cond.notify_all();
    }

    threadsafe_queue(const threadsafe_queue& other) = delete;
    threadsafe_queue& operator=(const threadsafe_queue& other) = delete;

    void Exit() {
        bstop.store(true);
        data_cond.notify_all();
    }

    bool wait_and_pop_timeout(T& value) {
        std::unique_lock<std::mutex> head_lock(head_mutex);
        auto res = data_cond.wait_for(head_lock, std::chrono::milliseconds(100),
                [&] {return head.get() != get_tail() || bstop.load() == true; });
        if (res == false) {
            return false;
        }

        if (bstop.load()) {
            return false;
        }

        value = std::move(*head->data);    
        head = std::move(head->next);
        return true;
    }

    std::shared_ptr<T> wait_and_pop() //  <------3
    {
        std::unique_ptr<node> const old_head = wait_pop_head();
        if (old_head == nullptr) {
            return nullptr;
        }
        return old_head->data;
    }

    bool  wait_and_pop(T& value)  //  <------4
    {
        std::unique_ptr<node> const old_head = wait_pop_head(value);
        if (old_head == nullptr) {
            return false;
        }
        return true;
    }


    std::shared_ptr<T> try_pop()
    {
        std::unique_ptr<node> old_head = try_pop_head();
        return old_head ? old_head->data : std::shared_ptr<T>();
    }

    bool try_pop(T& value)
    {
        std::unique_ptr<node> const old_head = try_pop_head(value);
        if (old_head) {
            return true;
        }
        return false;
    }

    bool empty()
    {
        std::lock_guard<std::mutex> head_lock(head_mutex);
        return (head.get() == get_tail());
    }

    void push(T new_value) //<------2
    {
        std::shared_ptr<T> new_data(
            std::make_shared<T>(std::move(new_value)));
        std::unique_ptr<node> p(new node);
        {
            std::lock_guard<std::mutex> tail_lock(tail_mutex);
            tail->data = new_data;
            node* const new_tail = p.get();
            new_tail->prev = tail;

            tail->next = std::move(p);

            tail = new_tail;
        }

        data_cond.notify_one();
    }

    bool try_steal(T& value) {
        std::unique_lock<std::mutex> tail_lock(tail_mutex,std::defer_lock);
        std::unique_lock<std::mutex>  head_lock(head_mutex, std::defer_lock);
        std::lock(tail_lock, head_lock);
        if (head.get() == tail)
        {
            return false;
        }

        node* prev_node = tail->prev;
        value = std::move(*(prev_node->data));
        tail = prev_node;
        tail->next = nullptr;
        return true;
    }
};

我们封装了一个简单轮询的线程池

#include <atomic>
#include "ThreadSafeQue.h"
#include "join_thread.h"

class simple_thread_pool
{
    std::atomic_bool done;
    //⇽-- - 1
    threadsafe_queue<std::function<void()> > work_queue; 
    //⇽-- - 2
    std::vector<std::thread> threads; 
    //⇽-- - 3
    join_threads joiner;    
    void worker_thread()
    {
        //⇽-- - 4
        while (!done)    
        {
            std::function<void()> task;
            //⇽-- - 5
            if (work_queue.try_pop(task))    
            {
                //⇽-- - 6
                task();    
            }
            else
            {
                //⇽-- - 7
                std::this_thread::yield();    
            }
        }
    }

    simple_thread_pool() :
        done(false), joiner(threads)
    {
        //⇽--- 8
        unsigned const thread_count = std::thread::hardware_concurrency();
        try
        {
            for (unsigned i = 0; i < thread_count; ++i)
            {
                //⇽-- - 9
                threads.push_back(std::thread(&simple_thread_pool::worker_thread, this));
            }
        }
        catch (...)
        {
            //⇽-- - 10
            done = true;
            throw;
        }
    }
public:
    static simple_thread_pool& instance() {
       static  simple_thread_pool pool;
       return pool;
    }
    ~simple_thread_pool()
    {
        //⇽-- - 11
        done = true;     
        for (unsigned i = 0; i < threads.size(); ++i)
        {
            //⇽-- - 9
            threads[i].join();
        }
    }
    template<typename FunctionType>
    void submit(FunctionType f)
    {
        //⇽-- - 12
        work_queue.push(std::function<void()>(f));    
    }
};
  1. worker_thread 即为线程的回调函数,回调函数内从队列中取出任务并处理,如果没有任务则调用yield释放cpu资源。

  2. submit函数比较简单,投递了一个返回值为void,参数为void的任务。这和我们之前自己设计的线程池(可执行任意参数类型,返回值不限的函数)相比功能稍差了一些。

获取任务完成结果

因为外部投递任务给线程池后要获取线程池执行任务的结果,我们之前自己设计的线程池采用的是future和decltype推断函数返回值的方式构造一个返回类型的future。

这里作者先封装一个可调用对象的类

class function_wrapper
{
    struct impl_base {
        virtual void call() = 0;
        virtual ~impl_base() {}
    };
    std::unique_ptr<impl_base> impl;
    template<typename F>
    struct impl_type : impl_base
    {
        F f;
        impl_type(F&& f_) : f(std::move(f_)) {}
        void call() { f(); }
    };
public:
    template<typename F>
    function_wrapper(F&& f) :
        impl(new impl_type<F>(std::move(f)))
    {}
    void operator()() { impl->call(); }
    function_wrapper() = default;
    function_wrapper(function_wrapper&& other) :
        impl(std::move(other.impl))
    {}
    function_wrapper& operator=(function_wrapper&& other)
    {
        impl = std::move(other.impl);
        return *this;
    }
    function_wrapper(const function_wrapper&) = delete;
    function_wrapper(function_wrapper&) = delete;
    function_wrapper& operator=(const function_wrapper&) = delete;
};
  1. impl_base 是一个基类,内部有一个纯虚函数call,以及一个虚析构,这样可以通过delete 基类指针动态析构子类对象。

  2. impl_type 继承了impl_base类,内部包含了一个可调用对象f,并且实现了构造函数和call函数,call内部调用可调用对象f。

  3. function_wrapper 内部有智能指针impl_base类型的unique_ptr变量impl, function_wrapper构造函数根据可调用对象f构造impl

  4. function_wrapper支持移动构造不支持拷贝和赋值。function_wrapper本质上就是当作task给线程池执行的。

可获取任务执行状态的线程池如下

class future_thread_pool
{
private:
    void worker_thread()
    {
        while (!done)
        {
            function_wrapper task;    

                if (work_queue.try_pop(task))
                {
                    task();
                }
                else
                {
                    std::this_thread::yield();
                }
        }
    }
public:

    static future_thread_pool& instance() {
        static  future_thread_pool pool;
        return pool;
    }
    ~future_thread_pool()
    {
        //⇽-- - 11
        done = true;
        for (unsigned i = 0; i < threads.size(); ++i)
        {
            //⇽-- - 9
            threads[i].join();
        }
    }

    template<typename FunctionType>
    std::future<typename std::result_of<FunctionType()>::type>   
        submit(FunctionType f)
    {
        typedef typename std::result_of<FunctionType()>::type result_type;   
            std::packaged_task<result_type()> task(std::move(f));   
            std::future<result_type> res(task.get_future());    
            work_queue.push(std::move(task));    
            return res;   
    }

private:
    future_thread_pool() :
        done(false), joiner(threads)
    {
        //⇽--- 8
        unsigned const thread_count = std::thread::hardware_concurrency();
        try
        {
            for (unsigned i = 0; i < thread_count; ++i)
            {
                //⇽-- - 9
                threads.push_back(std::thread(&future_thread_pool::worker_thread, this));
            }
        }
        catch (...)
        {
            //⇽-- - 10
            done = true;
            throw;
        }
    }

    std::atomic_bool done;
    //⇽-- - 1
    threadsafe_queue<function_wrapper> work_queue;
    //⇽-- - 2
    std::vector<std::thread> threads;
    //⇽-- - 3
    join_threads joiner;
};
  1. worker_thread内部从队列中pop任务并执行,如果没有任务则交出cpu资源。

  2. submit函数返回值为std::future<typename std::result_of<FunctionType()>::type>类型,通过std::result_of<FunctionType()>推断出函数执行的结果,然后通过::type推断出结果的类型,并且根据这个类型构造future,这样调用者就可以在投递完任务获取任务的执行结果了。

  3. submit函数内部我们将函数执行的结果类型定义为result_type类型,并且利用f构造一个packaged_task任务。通过task返回一个future给外部调用者,然后我们调用队列的push将task放入队列,注意队列存储的是function_wrapper,这里是利用task隐式构造了function_wrapper类型的对象。

利用条件变量等待

当我们的任务队列中没有任务的时候,可以让线程挂起,然后等待有任务投递到队列后在激活线程处理

class notify_thread_pool
{
private:
    void worker_thread()
    {
        while (!done)
        {

            auto task_ptr = work_queue.wait_and_pop();
            if (task_ptr == nullptr) {
                continue;
            }

            (*task_ptr)();
        }
    }
public:

    static notify_thread_pool& instance() {
        static  notify_thread_pool pool;
        return pool;
    }
    ~notify_thread_pool()
    {
        //⇽-- - 11
        done = true;
        work_queue.Exit();
        for (unsigned i = 0; i < threads.size(); ++i)
        {
            //⇽-- - 9
            threads[i].join();
        }
    }

    template<typename FunctionType>
    std::future<typename std::result_of<FunctionType()>::type>   
        submit(FunctionType f)
    {
        typedef typename std::result_of<FunctionType()>::type result_type;   
            std::packaged_task<result_type()> task(std::move(f));   
            std::future<result_type> res(task.get_future());    
            work_queue.push(std::move(task));    
            return res;   
    }

private:
    notify_thread_pool() :
        done(false), joiner(threads)
    {
        //⇽--- 8
        unsigned const thread_count = std::thread::hardware_concurrency();
        try
        {
            for (unsigned i = 0; i < thread_count; ++i)
            {
                //⇽-- - 9
                threads.push_back(std::thread(&notify_thread_pool::worker_thread, this));
            }
        }
        catch (...)
        {
            //⇽-- - 10
            done = true;
            work_queue.Exit();
            throw;
        }
    }

    std::atomic_bool done;
    //⇽-- - 1
    threadsafe_queue<function_wrapper> work_queue;
    //⇽-- - 2
    std::vector<std::thread> threads;
    //⇽-- - 3
    join_threads joiner;
};
  1. worker_thread内部调用了work_queue的wait_and_pop函数,如果队列中有任务直接返回,如果没任务则挂起。

  2. 另外我们在线程池的析构函数和异常处理时都增加了work_queue.Exit(); 这需要在我们的线程安全队列中增加Exit函数通知线程唤醒,因为线程发现队列为空会阻塞住。

void Exit() {
    bstop.store(true);
    data_cond.notify_all();
}

避免争夺

我们的任务队列只有一个,当向任务队列频繁投递任务,线程池中其他线程从队列中获取任务,队列就会频繁加锁和解锁,一般情况下性能不会有什么损耗,但是如果投递的任务较多,我们可以采取分流的方式,创建多个任务队列(可以和线程池中线程数相等),将任务投递给不同的任务队列,每个线程消费自己的队列即可,这样减少了线程间取任务的冲突。

#include "ThreadSafeQue.h"
#include <future>
#include "ThreadSafeQue.h"
#include "join_thread.h"
#include "FutureThreadPool.h"

class parrallen_thread_pool
{
private:

    void worker_thread(int index)
    {
        while (!done)
        {

            auto task_ptr = thread_work_ques[index].wait_and_pop();
            if (task_ptr == nullptr) {
                continue;
            }

            (*task_ptr)();
        }
    }
public:

    static parrallen_thread_pool& instance() {
        static  parrallen_thread_pool pool;
        return pool;
    }
    ~parrallen_thread_pool()
    {
        //⇽-- - 11
        done = true;
        for (unsigned i = 0; i < thread_work_ques.size(); i++) {
            thread_work_ques[i].Exit();
        }

        for (unsigned i = 0; i < threads.size(); ++i)
        {
            //⇽-- - 9
            threads[i].join();
        }
    }

    template<typename FunctionType>
    std::future<typename std::result_of<FunctionType()>::type>
        submit(FunctionType f)
    {
        int index = (atm_index.load() + 1) % thread_work_ques.size();
        atm_index.store(index);
        typedef typename std::result_of<FunctionType()>::type result_type;
        std::packaged_task<result_type()> task(std::move(f));
        std::future<result_type> res(task.get_future());
        thread_work_ques[index].push(std::move(task));
        return res;
    }

private:
    parrallen_thread_pool() :
        done(false), joiner(threads), atm_index(0)
    {
        //⇽--- 8
        unsigned const thread_count = std::thread::hardware_concurrency();
        try
        {
            thread_work_ques = std::vector < threadsafe_queue<function_wrapper>>(thread_count);

            for (unsigned i = 0; i < thread_count; ++i)
            {
                //⇽-- - 9
                threads.push_back(std::thread(&parrallen_thread_pool::worker_thread, this, i));
            }
        }
        catch (...)
        {
            //⇽-- - 10
            done = true;
            for (int i = 0; i < thread_work_ques.size(); i++) {
                thread_work_ques[i].Exit();
            }
            throw;
        }
    }

    std::atomic_bool done;
    //全局队列
    std::vector<threadsafe_queue<function_wrapper>> thread_work_ques;

    //⇽-- - 2
    std::vector<std::thread> threads;
    //⇽-- - 3
    join_threads joiner;
    std::atomic<int>  atm_index;
};
  1. 我们将任务队列变为多个//全局队列 std::vector<threadsafe_queue<function_wrapper>> thread_work_ques;.

  2. commit的时候根据atm_index索引自增后对总大小取余将任务投递给不同的队列。

  3. worker_thread增加了索引参数,每个线程的在回调的时候会根据自己的索引取出对应队列中的任务进行执行。

任务窃取

当本线程队列中的任务处理完了,它可以去别的线程的任务队列中看看是否有没处理的任务,帮助其他线程处理任务,简称任务窃取。

#include "ThreadSafeQue.h"
#include <future>
#include "ThreadSafeQue.h"
#include "join_thread.h"
#include "FutureThreadPool.h"

class steal_thread_pool
{
private:

    void worker_thread(int index)
    {
        while (!done)
        {
            function_wrapper wrapper;
            bool pop_res = thread_work_ques[index].try_pop(wrapper);
            if (pop_res) {
                wrapper();
                continue;
            }

            bool steal_res = false;
            for (int i = 0; i < thread_work_ques.size(); i++) {
                if (i == index) {
                    continue;
                }

                steal_res  = thread_work_ques[i].try_pop(wrapper);
                if (steal_res) {
                    wrapper();
                    break;
                }

            }

            if (steal_res) {
                continue;
            }

            std::this_thread::yield();
        }
    }
public:

    static steal_thread_pool& instance() {
        static  steal_thread_pool pool;
        return pool;
    }
    ~steal_thread_pool()
    {
        //⇽-- - 11
        done = true;
        for (unsigned i = 0; i < thread_work_ques.size(); i++) {
            thread_work_ques[i].Exit();
        }

        for (unsigned i = 0; i < threads.size(); ++i)
        {
            //⇽-- - 9
            threads[i].join();
        }
    }

    template<typename FunctionType>
    std::future<typename std::result_of<FunctionType()>::type>
        submit(FunctionType f)
    {
        int index = (atm_index.load() + 1) % thread_work_ques.size();
        atm_index.store(index);
        typedef typename std::result_of<FunctionType()>::type result_type;
        std::packaged_task<result_type()> task(std::move(f));
        std::future<result_type> res(task.get_future());
        thread_work_ques[index].push(std::move(task));
        return res;
    }

private:
    steal_thread_pool() :
        done(false), joiner(threads), atm_index(0)
    {
        //⇽--- 8
        unsigned const thread_count = std::thread::hardware_concurrency();
        try
        {
            thread_work_ques = std::vector < threadsafe_queue<function_wrapper>>(thread_count);

            for (unsigned i = 0; i < thread_count; ++i)
            {
                //⇽-- - 9
                threads.push_back(std::thread(&steal_thread_pool::worker_thread, this, i));
            }
        }
        catch (...)
        {
            //⇽-- - 10
            done = true;
            for (int i = 0; i < thread_work_ques.size(); i++) {
                thread_work_ques[i].Exit();
            }
            throw;
        }
    }

    std::atomic_bool done;
    //全局队列
    std::vector<threadsafe_queue<function_wrapper>> thread_work_ques;

    //⇽-- - 2
    std::vector<std::thread> threads;
    //⇽-- - 3
    join_threads joiner;
    std::atomic<int>  atm_index;
};
  1. worker_thread中本线程会先处理自己队列中的任务,如果自己队列中没有任务则从其它线程的任务队列中获取任务。如果都没有则交出cpu资源。

  2. 为了实现try_steal的功能,我们需要修改线程安全队列threadsafe_queue,增加try_steal函数

bool try_steal(T& value) {
    std::unique_lock<std::mutex> tail_lock(tail_mutex,std::defer_lock);
    std::unique_lock<std::mutex>  head_lock(head_mutex, std::defer_lock);
    std::lock(tail_lock, head_lock);
    if (head.get() == tail)
    {
        return false;
    }

    node* prev_node = tail->prev;
    value = std::move(*(prev_node->data));
    tail = prev_node;
    tail->next = nullptr;
    return true;
}

因为try_steal是从队列的尾部弹出数据,为了防止此时有其他线程从头部弹出数据造成操作同一个节点,或者其他线程弹出头部数据后接着修改头部节点为下一个节点,此时本线程正在弹出尾部节点,而尾部节点正好是头部的下一个节点造成数据混乱,此时加了两把锁,对头部和尾部都加锁。

我们这里所说的弹出尾部节点不是弹出tail,而是tail的前一个节点,因为tail是尾部表示一个空节点,tail前边的节点才是尾部数据的节点,为了实现反向查找,我们为node增加了prev指针

struct node
{
    std::shared_ptr<T> data;
    std::unique_ptr<node> next;
    node* prev;
};

所以在push节点的时候也要把这个节点的prev指针指向前一个节点

void push(T new_value) //<------2
{
    std::shared_ptr<T> new_data(
    std::make_shared<T>(std::move(new_value)));
    std::unique_ptr<node> p(new node);
    {
        std::lock_guard<std::mutex> tail_lock(tail_mutex);
        tail->data = new_data;
        node* const new_tail = p.get();
        new_tail->prev = tail;
        tail->next = std::move(p);
        tail = new_tail;
    }
        data_cond.notify_one();
}

整体来说steal版本的线程池就这些内容和前边变化不大。

测试

测试用例已经在源代码中写好,感兴趣可以看下

源码链接:

https://gitee.com/secondtonone1/boostasio-learn/tree/master/concurrent/day22-ThreadPool

视频链接:

https://space.bilibili.com/271469206/channel/collectiondetail?sid=1623290

results matching ""

    No results matching ""