实现线程池cached模式动态增加线程数量和超时回收线程

前言

线程池通过预先创建一定数量的线程并保存在内存中,可以避免频繁地创建和销毁线程,降低线程创建和销毁的开销

简化任务调度:只需要将任务提交给线程池,而不需要关心线程的创建、管理和销毁等细节。线程池会自动将任务分配给空闲的线程执行。

代码位置:https://gitee.com/zhongshield/thread_pool

调用方式

调用方继承Task基类,重写Run接口,

通过线程池提供的SubmiTask接口提交任务。

通过Result接收线程池的返回值

示例:实现通过线程池计算区间累加和并返回累加和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include "thread_pool.h"
#include <chrono>
#include <iostream>

class MyTask : public Task {
Any Run();
};

Any MyTask::Run()
{
std::cout << "tid: " << std::this_thread::get_id() << " begin" << std::endl;
std::this_thread::sleep_for(std::chrono::seconds(2));
std::cout << "tid: " << std::this_thread::get_id() << " end" << std::endl;
return " ";
}

class MyTaskSecond : public Task {
public:
MyTaskSecond(int begin, int end) : begin_(begin), end_(end) {}

// 问题:如何设计run函数的返回值,可以表示任意类型
// Java Python Object类是所有类型的基类
// C++17 Any类型
Any Run();

private:
int begin_;
int end_;
};

Any MyTaskSecond::Run()
{
std::cout << "tid: " << std::this_thread::get_id() << " ***begin*** " << std::endl;
std::this_thread::sleep_for(std::chrono::seconds(3));
int sum = 0;
for (int i = begin_; i < end_; i++) {
sum += i;
}
std::cout << "tid: " << std::this_thread::get_id() << " ***end*** " << std::endl;
return sum;
}

int main()
{
ThreadPool pool;
pool.SetMode(PoolMode::MODE_CACHED);
pool.Start();

// 获取任务的返回值
Result result1 = pool.SubmitTask(std::make_shared<MyTaskSecond>(1, 1000));
Result result2 = pool.SubmitTask(std::make_shared<MyTaskSecond>(1001, 2000));
Result result3 = pool.SubmitTask(std::make_shared<MyTaskSecond>(2001, 3000));

Result result4 = pool.SubmitTask(std::make_shared<MyTaskSecond>(1, 3000));

Result result5 = pool.SubmitTask(std::make_shared<MyTaskSecond>(1, 3000));
Result result6 = pool.SubmitTask(std::make_shared<MyTaskSecond>(1, 3000));

int value1 = result1.GetValue().cast<int>(); // GetValue返回一个Any类型
int value2 = result2.GetValue().cast<int>();
int value3 = result3.GetValue().cast<int>();
int value4 = result4.GetValue().cast<int>();

std::cout << "value1 + value2 + value3: " << value1 + value2 + value3 << std::endl;
std::cout << "value4: " << value1 + value2 + value3 << std::endl;
std::this_thread::sleep_for(std::chrono::seconds(20));
}

执行结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
z@zzz:~/Z/thread_pool$ g++ -std=c++17 main.cpp thread_pool.cpp -lpthread
z@zzz:~/Z/thread_pool$ ./a.out
*** create new thread ***
*** create new thread ***
tid: 140173518145280 尝试获取任务...
tid: 140173518145280 获取任务成功...
tid: 140173518145280 ***begin***
tid: 140173509752576 尝试获取任务...
tid: 140173509752576 获取任务成功...
tid: 140173509752576 ***begin***
tid: 140173501359872 尝试获取任务...
tid: 140173501359872 获取任务成功...
tid: 140173501359872 ***begin***
tid: 140173492967168 尝试获取任务...
tid: 140173492967168 获取任务成功...
tid: 140173492967168 ***begin***
tid: 140173484574464 尝试获取任务...
tid: 140173484574464 获取任务成功...
tid: 140173484574464 ***begin***
tid: 140173476181760 尝试获取任务...
tid: 140173476181760 获取任务成功...
tid: 140173476181760 ***begin***
tid: 140173518145280 ***end***
tid: 140173518145280 尝试获取任务...
tid: 140173509752576 ***end***
tid: 140173509752576 尝试获取任务...
tid: 140173501359872 ***end***
tid: 140173501359872 尝试获取任务...
tid: 140173492967168 ***end***
tid: 140173492967168 尝试获取任务...
value1 + value2 + value3: 4495500
value4: 4495500
tid: 140173484574464 ***end***
tid: 140173484574464 尝试获取任务...
tid: 140173476181760 ***end***
tid: 140173476181760 尝试获取任务...
threadId: 140173484574464 exit.
threadId: 140173492967168 exit.

实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#ifndef THREAD_POOL_H
#define THREAD_POOL_H

#include <vector>
#include <queue>
#include <memory>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <atomic>
#include <thread>
#include <unordered_map>

// Any类型:可以接收任意类型数据
class Any {
public:
Any() = default;
~Any() = default;
Any(Any&&) = default;
Any& operator=(Any&&) = default;

template <typename T>
Any(T data) : basePtr_(std::make_unique<Derive<T>>(data)) {}

// 这个方法能把Any对象里面存储的data数据提取出来
template <typename T>
T cast()
{
Derive<T>* devicePtr = dynamic_cast<Derive<T>*>(basePtr_.get());
if (devicePtr == nullptr) {
throw "error: type is unmatch.";
}
return devicePtr->data_;
}

private:
// 基类类型
class Base
{
private:
public:
virtual ~Base() = default;
};

// 派生类类型
template <typename T>
class Derive : public Base
{
public:
Derive(T data) : data_(data)
{}
T data_; // 保存了任意类型数据
};

private:
std::unique_ptr<Base> basePtr_;
};

// 实现一个信号量类
class Semaphore {
public:
Semaphore() = default;
Semaphore(int limit) : resLimit_(limit) {};
~Semaphore() = default;
Semaphore& operator=(Semaphore&&) = default;
Semaphore& operator=(const Semaphore&) = default;

void Wait()
{
std::unique_lock<std::mutex> lock(mtx_);
cond_.wait(lock, [&]()->bool { return resLimit_ > 0; });
resLimit_--;
}

void Post()
{
std::unique_lock<std::mutex> lock(mtx_);
resLimit_++;
cond_.notify_all();
}

private:
std::mutex mtx_;
std::condition_variable cond_;
int resLimit_ = 0;
};

class Task;

// 实现接收提交到线程的task任务执行完任务的返回值类型 Result
class Result {
public:
Result(std::shared_ptr<Task> task, bool isValid = true);
~Result() = default;
Result(const Result&) = default;
Result& operator=(const Result&) = default;
Result(Result&& result) = default;
Result& operator=(Result&& result) = default;
// SetValue方法 获取任务的返回值并设置到any_成员变量上
void SetValue(Any any);

// GetValue方法 用户调用该方法获取task的返回值
Any GetValue();

private:
Any any_; // 存储任务的返回值
Semaphore sem_; // 线程通信信号量
std::shared_ptr<Task> task_; // 指向对应获取返回值的所属任务
std::atomic_bool isValid_;
};

// 任务抽象基类
class Task {
public:
Task();
~Task() = default;
virtual Any Run() = 0;
void Exec();
void SetResult(Result* res);

private:
Result* resultPtr_; // Result对象的生命周期晚于Task对象
};

class Thread {
public:
using ThreadFunc = std::function<void(int)>;
Thread(ThreadFunc func);
~Thread();

// 启动线程
void Start();

// 获取线程id
int GetId() const;

private:
ThreadFunc func_;
static int generateId_;
int threadId_; // 保存线程id
};

/* 线程模式
* fixed 线程数量固定
* cached 线程数量不固定
*/
enum class PoolMode {
MODE_FIXED,
MODE_CACHED
};

class ThreadPool {
public:
ThreadPool();
~ThreadPool();
ThreadPool(const ThreadPool&) = delete;
ThreadPool& operator=(const ThreadPool&) = delete;

// 开启线程池
void Start(int initThreadSize = 4);

// 设置线程池的工作模式:fixed 和 cached
void SetMode(PoolMode mode);
PoolMode GetMode();

// 设置task队列任务量上限阈值
void SetTaskMaxHold(size_t size);

// 向线程池提交任务
Result SubmitTask(std::shared_ptr<Task> task);

// 线程池是否正在运行
bool CheckPoolRunning();

// 设置线程池cached模式线程上限阈值
void SetCachedThreadMaxHold(size_t size);


private:
// 线程函数
void ThreadFunc(int threadId);

private:
// std::vector<std::unique_ptr<Thread>> threads_; // 线程列表
std::unordered_map<int, std::unique_ptr<Thread>> threads_;
size_t initThreadSize_; // 初始线程数量
std::queue<std::shared_ptr<Task>> taskQue_; // 任务队列
std::atomic_int taskSize_; // 任务数量
size_t taskThreadSizeMaxHold_; // 最大任务数量
std::mutex taskMtx_;
std::condition_variable notFull_; // 表示任务队列不满
std::condition_variable notEmpty_; // 表示任务队列不空
PoolMode mode_; // 当前线程池的工作模式
std::atomic_bool isPoolRunning_; // 线程池是否已运行
std::atomic_int idleThreadSize_; // 空闲线程的数量
size_t cachedThreadSizeHold_; // cached模式下线程数量上限
std::atomic_int cachedCurThreadSize_; //cached模式下线程总数量
};

#endif // THREAD_POOL_H
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
#include "thread_pool.h"
#include <iostream>

const size_t TASK_SIZE_MAX_HOLD = 1024;
const size_t CACHED_THREAD_SIZE_MAX_HOLD = 8;
const int THREAD_MAX_IDLE_TIME = 10; // 单位:秒

ThreadPool::ThreadPool()
: taskSize_(0),
taskThreadSizeMaxHold_(TASK_SIZE_MAX_HOLD),
mode_(PoolMode::MODE_FIXED),
idleThreadSize_(0),
cachedThreadSizeHold_(CACHED_THREAD_SIZE_MAX_HOLD),
isPoolRunning_(false) {}

ThreadPool::~ThreadPool() {}

void ThreadPool::SetMode(PoolMode mode)
{
if (isPoolRunning_) {
return;
}
mode_ = mode;
}

PoolMode ThreadPool::GetMode()
{
return mode_;
}

void ThreadPool::SetTaskMaxHold(size_t size)
{
taskThreadSizeMaxHold_ = size;
}

void ThreadPool::Start(int initThreadSize)
{
isPoolRunning_ = true;

// 记录初始线程个数
initThreadSize_ = initThreadSize;
cachedCurThreadSize_ = initThreadSize;

for (int i = 0; i < initThreadSize_; i++) {
auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::ThreadFunc, this, std::placeholders::_1));
int threadId = ptr->GetId();
threads_.emplace(threadId, std::move(ptr));
}

for (int i = 0; i < initThreadSize_; i++) {
threads_[i]->Start();
}

idleThreadSize_ = initThreadSize;
}

Result ThreadPool::SubmitTask(std::shared_ptr<Task> task)
{
std::unique_lock<std::mutex> lock(taskMtx_);

// 线程通信 等待任务队列有空余 wait wait_for wait_until
// 用户提交任务 最场阻塞时间不能超过1s,否则提交任务失败,返回
if(!notFull_.wait_for(lock, std::chrono::seconds(1),
[&]()->bool { return taskQue_.size() < taskThreadSizeMaxHold_;})) {
// 表示等待1s后,条件依然没有满足
std::cerr << "task queue is full, submit task fail." << std::endl;
return Result(task, false);
}

taskQue_.emplace(task);
taskSize_++;
notEmpty_.notify_all();

// cached模式 任务处理比较紧急 场景:小而快的任务 需要根据任务数量和空闲线程数量来决定是否新创建线程来处理任务
if (mode_ == PoolMode::MODE_CACHED && taskSize_ > idleThreadSize_ && cachedCurThreadSize_ < cachedThreadSizeHold_) {
std::cout << " *** create new thread ***" << std::endl;

// 创建新的线程对象
auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::ThreadFunc, this, std::placeholders::_1));
int threadId = ptr->GetId();
threads_.emplace(threadId, std::move(ptr));
cachedCurThreadSize_++;
idleThreadSize_++;
threads_[threadId]->Start();
}

// 返回任务的Result对象
return Result(task);
}

bool ThreadPool::CheckPoolRunning()
{
return isPoolRunning_;
}

void ThreadPool::SetCachedThreadMaxHold(size_t size)
{
if (CheckPoolRunning()) {
return;
}
cachedThreadSizeHold_ = size;
}

// 定义线程函数 线程池的所有线程从任务队列里消费任务
void ThreadPool::ThreadFunc(int threadId)
{
auto lastTime = std::chrono::high_resolution_clock().now();

for(;;) {
std::shared_ptr<Task> task;
// 通过作用域{}来释放锁
{
// 先获取锁
std::unique_lock<std::mutex> lock(taskMtx_);

std::cout << "tid: " << std::this_thread::get_id() << " 尝试获取任务..." << std::endl;

// cached模式下,有可能已经创建了很多线程,当空闲时间超过60s,把多余的线程结束回收(超过initThreadSize_数量的线程要进行回收)
// 当前时间 - 上一次线程执行的时间 > 60s
if (mode_ == PoolMode::MODE_CACHED)
{
// 每一秒返回一次 怎么区分:是超时返回还是有任务执行返回 std::cv_status::timeout
while (taskQue_.size() == 0) {
if (std::cv_status::timeout == notEmpty_.wait_for(lock, std::chrono::seconds(1))) {
auto now = std::chrono::high_resolution_clock().now();
auto dur = std::chrono::duration_cast<std::chrono::seconds>(now - lastTime);
if (dur.count() >= THREAD_MAX_IDLE_TIME && cachedCurThreadSize_ > initThreadSize_) {
// 开始回收当前线程
// 修改记录线程数量的变量cachedCurThreadSize_
threads_.erase(threadId);
cachedCurThreadSize_--;
idleThreadSize_--;
std::cout << "threadId: " << std::this_thread::get_id() << " exit." << std::endl;
return;
}
}
}
notEmpty_.wait(lock, [&]()->bool { return !taskQue_.empty(); });

} else {
// 等待notEmpty条件变量
notEmpty_.wait(lock, [&]()->bool { return !taskQue_.empty(); });
}

idleThreadSize_--;

// 从任务队列中取一个任务出来
std::cout << "tid: " << std::this_thread::get_id() << " 获取任务成功..." << std::endl;
task = taskQue_.front();
taskQue_.pop();
taskSize_--;

// 如果依然有剩余任务,继续通知其它线程执行任务
if (taskQue_.size() > 0) {
notEmpty_.notify_all();
}

// 取出一个任务,通知任务队列已不满
notFull_.notify_all();
}

// 当前线程负责执行这个任务
// 【重要】:锁不要等到任务执行完再释放,所以上面用作用域{}及时释放锁
if (task != nullptr) {
// task->Run();
task->Exec();
}
idleThreadSize_++;
lastTime = std::chrono::high_resolution_clock().now(); // 更新线程执行完任务的时间
}
}

// Thread 方法实现
int Thread::generateId_ = 0;

Thread::Thread(ThreadFunc func) : func_(func), threadId_(generateId_++)
{}

Thread::~Thread() {}

void Thread::Start()
{
std::thread t(func_, threadId_);
t.detach();
}

int Thread::GetId() const {
return threadId_;
}

/// Task 方法实现
Task::Task() : resultPtr_(nullptr)
{}

void Task::Exec()
{
if (resultPtr_ != nullptr) {
resultPtr_->SetValue(Run()); // run发生多态调用
}
}

void Task::SetResult(Result* res)
{
resultPtr_ = res;
}

/// Result 方法实现
Result::Result(std::shared_ptr<Task> task, bool isValid) : task_(task), isValid_(isValid)
{
task_->SetResult(this);
}

Any Result::GetValue()
{
if (!isValid_) {
return " ";
}
sem_.Wait(); // task任务如果没有执行完毕,这里会阻塞用户线程
return std::move(any_);
}

void Result::SetValue(Any any)
{
any_ = std::move(any);
sem_.Post();
}

总结

线程池ThreadPool 提供SubmitTask接口用于提交任务。SubmitTask接口内部通过wait_for等待条件变量notFull_条件成立(条件为任务队列不满),等待1s后条件仍然不成立,打印任务提交失败。如果任务队列不满,提交的任务会放到任务队列中。

线程池ThreadPool 预先执行Start接口,Start接口内部会创建线程对象,并绑定线程函数为线程池ThreadPool 的成员函数ThreadFunc。

线程池ThreadPool 的成员函数ThreadFunc内部,会获取锁,通过wait等待条件变量notEmpty_条件成立(条件为任务队列不空),如果不空,从任务队列中取一个任务,在任务执行前释放锁(否则会阻塞其它线程获取锁,失去线程池的意义)。

线程池ThreadPool的SubmitTask接口返回值是Result,通过task指针构造Result对象 Result(task),并在Result构造函数中调用

task_->SetResult(this); 将Result和task联系起来,线程执行完任务,会通过task向Result设置返回值SetValue,SetValue内部会调用信号量的Post()释放锁。

调用方调用result.GetValue()获取返回值,GetValue()内部会调用信号量sem_.Wait(),如果任务还未执行完毕,会阻塞等待SetValue内部会调用信号量的Post()释放锁。

GetValue阻塞等待可能对调用方不友好,可以增加回调函数通知任务已执行完毕。或者在Result中增加标志/函数(用于判断任务是否执行完毕),从而不在GetValue()中进行阻塞等待。

add:

线程池是cached模式时,SubmitTask接口内部会根据任务数量和空闲线程数量来决定是否新创建线程来处理任务。

cached模式下,有可能已经创建了很多线程,当空闲时间超过60s,把多余的线程结束回收(超过initThreadSize_数量的线程要进行回收)