实现接收线程池任务返回值务

前言

线程池通过预先创建一定数量的线程并保存在内存中,可以避免频繁地创建和销毁线程,降低线程创建和销毁的开销

简化任务调度:只需要将任务提交给线程池,而不需要关心线程的创建、管理和销毁等细节。线程池会自动将任务分配给空闲的线程执行。

代码位置:https://gitee.com/zhongshield/thread_pool

调用方式

调用方继承Task基类,重写Run接口,

通过线程池提供的SubmiTask接口提交任务。

通过Result接收线程池的返回值

示例:实现通过线程池计算区间累加和并返回累加和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include "thread_pool.h"
#include <chrono>
#include <iostream>

class MyTask : public Task {
Any Run();
};

Any MyTask::Run()
{
std::cout << "tid: " << std::this_thread::get_id() << " begin" << std::endl;
std::this_thread::sleep_for(std::chrono::seconds(2));
std::cout << "tid: " << std::this_thread::get_id() << " end" << std::endl;
return " ";
}

class MyTaskSecond : public Task {
public:
MyTaskSecond(int begin, int end) : begin_(begin), end_(end) {}

// 问题:如何设计run函数的返回值,可以表示任意类型
// Java Python Object类是所有类型的基类
// C++17 Any类型
Any Run();

private:
int begin_;
int end_;
};

Any MyTaskSecond::Run()
{
std::cout << "tid: " << std::this_thread::get_id() << " ***begin*** " << std::endl;
int sum = 0;
for (int i = begin_; i < end_; i++) {
sum += i;
}
std::cout << "tid: " << std::this_thread::get_id() << " ***end*** " << std::endl;
return sum;
}

int main()
{
ThreadPool pool;
pool.Start();

// 获取任务的返回值
Result result1 = pool.SubmitTask(std::make_shared<MyTaskSecond>(1, 1000));
Result result2 = pool.SubmitTask(std::make_shared<MyTaskSecond>(1001, 2000));
Result result3 = pool.SubmitTask(std::make_shared<MyTaskSecond>(2001, 3000));

Result result4 = pool.SubmitTask(std::make_shared<MyTaskSecond>(1, 3000));

int value1 = result1.GetValue().cast<int>(); // GetValue返回一个Any类型
int value2 = result2.GetValue().cast<int>();
int value3 = result3.GetValue().cast<int>();
int value4 = result4.GetValue().cast<int>();

std::cout << "value1 + value2 + value3: " << value1 + value2 + value3 << std::endl;
std::cout << "value4: " << value1 + value2 + value3 << std::endl;

// pool.SubmitTask(std::make_shared<MyTask>());
// pool.SubmitTask(std::make_shared<MyTask>());
// pool.SubmitTask(std::make_shared<MyTask>());
// pool.SubmitTask(std::make_shared<MyTask>());
// pool.SubmitTask(std::make_shared<MyTask>());
std::this_thread::sleep_for(std::chrono::seconds(5));
}

执行结果

1
2
3
4
5
6
7
8
9
10
11
12
ubuntu@ubuntu$ g++ -std=c++17 main.cpp thread_pool.cpp -lpthread
ubuntu@ubuntu$ ./a.out
tid: 140636074219264 ***begin***
tid: 140636074219264 ***end***
tid: 140636082611968 ***begin***
tid: 140636082611968 ***end***
tid: 140636065826560 ***begin***
tid: 140636091004672 ***begin***
tid: 140636091004672 ***end***
tid: 140636065826560 ***end***
value1 + value2 + value3: 4495500
value4: 4495500

实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#ifndef THREAD_POOL_H
#define THREAD_POOL_H

#include <vector>
#include <queue>
#include <memory>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <atomic>
#include <thread>

// Any类型:可以接收任意类型数据
class Any {
public:
Any() = default;
~Any() = default;
Any(Any&&) = default;
Any& operator=(Any&&) = default;

template <typename T>
Any(T data) : basePtr_(std::make_unique<Derive<T>>(data)) {}

// 这个方法能把Any对象里面存储的data数据提取出来
template <typename T>
T cast()
{
Derive<T>* devicePtr = dynamic_cast<Derive<T>*>(basePtr_.get());
if (devicePtr == nullptr) {
throw "error: type is unmatch.";
}
return devicePtr->data_;
}

private:
// 基类类型
class Base
{
private:
public:
virtual ~Base() = default;
};

// 派生类类型
template <typename T>
class Derive : public Base
{
public:
Derive(T data) : data_(data)
{}
T data_; // 保存了任意类型数据
};

private:
std::unique_ptr<Base> basePtr_;
};

// 实现一个信号量类
class Semaphore {
public:
Semaphore() = default;
Semaphore(int limit) : resLimit_(limit) {};
~Semaphore() = default;
Semaphore& operator=(Semaphore&&) = default;
Semaphore& operator=(const Semaphore&) = default;

void Wait()
{
std::unique_lock<std::mutex> lock(mtx_);
cond_.wait(lock, [&]()->bool { return resLimit_ > 0; });
resLimit_--;
}

void Post()
{
std::unique_lock<std::mutex> lock(mtx_);
resLimit_++;
cond_.notify_all();
}

private:
std::mutex mtx_;
std::condition_variable cond_;
int resLimit_ = 0;
};

class Task;

// 实现接收提交到线程的task任务执行完任务的返回值类型 Result
class Result {
public:
Result(std::shared_ptr<Task> task, bool isValid = true);
~Result() = default;
Result(const Result&) = default;
Result& operator=(const Result&) = default;
Result(Result&& result) = default;
Result& operator=(Result&& result) = default;
// SetValue方法 获取任务的返回值并设置到any_成员变量上
void SetValue(Any any);

// GetValue方法 用户调用该方法获取task的返回值
Any GetValue();

private:
Any any_; // 存储任务的返回值
Semaphore sem_; // 线程通信信号量
std::shared_ptr<Task> task_; // 指向对应获取返回值的所属任务
std::atomic_bool isValid_;
};

// 任务抽象基类
class Task {
public:
Task();
~Task() = default;
virtual Any Run() = 0;
void Exec();
void SetResult(Result* res);

private:
Result* resultPtr_; // Result对象的生命周期晚于Task对象
};

class Thread {
public:
using ThreadFunc = std::function<void()>;
Thread(ThreadFunc func);
~Thread();
void Start();

private:
ThreadFunc func_;
};

/* 线程模式
* fixed 线程数量固定
* cached 线程数量不固定
*/
enum class PoolMode {
MODE_FIXED,
MODE_CACHED
};

class ThreadPool {
public:
ThreadPool();
~ThreadPool();
void Start(int initThreadSize = 4);
void SetMode(PoolMode mode);
// void SetInitThreadSize();
void SetTaskThreadMaxHold(size_t size);
Result SubmitTask(std::shared_ptr<Task> task);
ThreadPool(const ThreadPool&) = delete;
ThreadPool& operator=(const ThreadPool&) = delete;

private:
// 线程函数
void ThreadFunc();

private:
std::vector<Thread*> threads_;
size_t initThreadSize_; // 初始线程数量
std::queue<std::shared_ptr<Task>> taskQue_; // 任务队列
std::atomic_int taskSize_; // 任务数量
size_t taskThreadSizeMaxHold_; // 最大任务数量
std::mutex taskMtx_;
std::condition_variable notFull_; // 表示任务队列不满
std::condition_variable notEmpty_; // 表示任务队列不空
PoolMode mode_; // 当前线程池的工作模式
};

#endif // THREAD_POOL_H
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include "thread_pool.h"
#include <iostream>

const size_t TASK_THREAD_SIZE_MAX_HOLD = 1024;

ThreadPool::ThreadPool()
: taskSize_(0),
taskThreadSizeMaxHold_(TASK_THREAD_SIZE_MAX_HOLD),
mode_(PoolMode::MODE_FIXED) {}

ThreadPool::~ThreadPool() {}

void ThreadPool::SetMode(PoolMode mode)
{
mode_ = mode;
}

void ThreadPool::SetTaskThreadMaxHold(size_t size)
{
taskThreadSizeMaxHold_ = size;
}

void ThreadPool::Start(int initThreadSize)
{
initThreadSize_ = initThreadSize;

for (int i = 0; i < initThreadSize_; i++) {
threads_.emplace_back(new Thread(std::bind(&ThreadPool::ThreadFunc, this)));
}

for (int i = 0; i < initThreadSize_; i++) {
threads_[i]->Start();
}
}

Result ThreadPool::SubmitTask(std::shared_ptr<Task> task)
{
std::unique_lock<std::mutex> lock(taskMtx_);

// 线程通信 等待任务队列有空余 wait wait_for wait_until
// 用户提交任务 最场阻塞时间不能超过1s,否则提交任务失败,返回
if(!notFull_.wait_for(lock, std::chrono::seconds(1),
[&]()->bool { return taskQue_.size() < taskThreadSizeMaxHold_;})) {
// 表示等待1s后,条件依然没有满足
std::cerr << "task queue is full, submit task fail." << std::endl;
return Result(task, false);
}

taskQue_.emplace(task);
taskSize_++;
notEmpty_.notify_all();
return Result(task);
}

// 定义线程函数 线程池的所有线程从任务队列里消费任务
void ThreadPool::ThreadFunc()
{
// std::cout << "Thread begin, tid: " << std::this_thread::get_id() << std::endl;
// std::cout << "Thread end, tid: " << std::this_thread::get_id() << std::endl;

for(;;) {
std::shared_ptr<Task> task;
// 通过作用域{}来释放锁
{
// 先获取锁
std::unique_lock<std::mutex> lock(taskMtx_);

// 等待notEmpty条件变量
notEmpty_.wait(lock, [&]()->bool { return !taskQue_.empty(); });

// 从任务队列中取一个任务出来
task = taskQue_.front();
taskQue_.pop();
taskSize_--;

// 如果依然有剩余任务,继续通知其它线程执行任务
if (taskQue_.size() > 0) {
notEmpty_.notify_all();
}

// 取出一个任务,通知任务队列已不满
notFull_.notify_all();
}

// 当前线程负责执行这个任务
// 【重要】:锁不要等到任务执行完再释放,所以上面用作用域{}及时释放锁
if (task != nullptr) {
// task->Run();
task->Exec();
}

}
}

Thread::Thread(ThreadFunc func) : func_(func)
{}

Thread::~Thread() {}

void Thread::Start()
{
std::thread t(func_);
t.detach();
}

/// Task 方法实现
Task::Task() : resultPtr_(nullptr)
{}

void Task::Exec()
{
if (resultPtr_ != nullptr) {
resultPtr_->SetValue(Run()); // run发生多态调用
}
}

void Task::SetResult(Result* res)
{
resultPtr_ = res;
}

/// Result 方法实现
Result::Result(std::shared_ptr<Task> task, bool isValid) : task_(task), isValid_(isValid)
{
task_->SetResult(this);
}

Any Result::GetValue()
{
if (!isValid_) {
return " ";
}
sem_.Wait(); // task任务如果没有执行完毕,这里会阻塞用户线程
return std::move(any_);
}

void Result::SetValue(Any any)
{
any_ = std::move(any);
sem_.Post();
}

总结

线程池ThreadPool 提供SubmitTask接口用于提交任务。SubmitTask接口内部通过wait_for等待条件变量notFull_条件成立(条件为任务队列不满),等待1s后条件仍然不成立,打印任务提交失败。如果任务队列不满,提交的任务会放到任务队列中。

线程池ThreadPool 预先执行Start接口,Start接口内部会创建线程对象,并绑定线程函数为线程池ThreadPool 的成员函数ThreadFunc。

线程池ThreadPool 的成员函数ThreadFunc内部,会获取锁,通过wait等待条件变量notEmpty_条件成立(条件为任务队列不空),如果不空,从任务队列中取一个任务,在任务执行前释放锁(否则会阻塞其它线程获取锁,失去线程池的意义)。

add:

线程池ThreadPool的SubmitTask接口返回值是Result,通过task指针构造Result对象 Result(task),并在Result构造函数中调用

task_->SetResult(this); 将Result和task联系起来,线程执行完任务,会通过task向Result设置返回值SetValue,SetValue内部会调用信号量的Post()释放锁。

调用方调用result.GetValue()获取返回值,GetValue()内部会调用信号量sem_.Wait(),如果任务还未执行完毕,会阻塞等待SetValue内部会调用信号量的Post()释放锁。

GetValue阻塞等待可能对调用方不友好,可以增加回调函数通知任务已执行完毕。或者在Result中增加标志/函数(用于判断任务是否执行完毕),从而不在GetValue()中进行阻塞等待。