TazGraph Project v0.1.0
Loading...
Searching...
No Matches
Threader.h
1#pragma once
2#include <functional>
3#include <queue>
4#include <vector>
5#include <thread>
6#include <mutex>
7#include <atomic>
8#include <condition_variable>
9
/// FIFO of parameterless jobs shared between producers and worker threads.
///
/// Producers enqueue work with addTask(). Workers block in getTask() until a job
/// is available or shutdown has been requested, run the job, and then report it
/// with completeTask(). remaining_tasks counts jobs that have been added but not
/// yet reported complete; waitUntilDone() spins until that count reaches zero.
struct TaskQueue {
    std::deque<std::function<void()>> tasks;   // pending jobs; only touched while mutex_ is held
    std::mutex mutex_;                         // guards tasks and pairs with taskCondition
    std::condition_variable taskCondition;     // signalled when a job is added (and by whoever requests shutdown)
    std::atomic<int> remaining_tasks{ 0 };     // jobs added but not yet completed
    // Public flag that outside code can set without holding mutex_, so it has to
    // be atomic for the reads in getTask() to be race-free.
    std::atomic<bool> shuttingDown{ false };

    /// Enqueues a job and wakes one waiting worker.
    /// @param callback job to run; it is moved into the queue.
    void addTask(std::function<void()>&& callback) {
        {
            std::lock_guard<std::mutex> lock(mutex_);
            tasks.emplace_back(std::move(callback));
            // Count the job while still holding the lock: a worker needs the same
            // lock to pop it, so completeTask() can never run before this increment
            // and the counter can never dip below zero or read zero for a queued job.
            // Incrementing after the push also leaves the counter untouched if
            // emplace_back throws.
            ++remaining_tasks;
        }
        taskCondition.notify_one();
    }

    /// Blocks until a job is available or shutdown is requested.
    /// @param task receives the next job when the call succeeds.
    /// @return true if a job was handed out; false once shuttingDown is set.
    ///         Jobs still queued at that point are not handed out.
    bool getTask(std::function<void()>& task) {
        std::unique_lock<std::mutex> lock(mutex_);
        taskCondition.wait(lock, [this] { return !tasks.empty() || shuttingDown; });

        // After the wait either a job is queued or shutdown was requested;
        // shutdown takes precedence even if jobs are still pending.
        if (shuttingDown || tasks.empty()) return false;

        task = std::move(tasks.front());
        tasks.pop_front();
        return true;
    }

    /// Spins (yielding the CPU between checks) until every added job has been
    /// reported through completeTask().
    void waitUntilDone() const {
        while (remaining_tasks > 0) {
            std::this_thread::yield();
        }
    }

    /// Marks one previously added job as finished.
    void completeTask() {
        --remaining_tasks;
    }
};
46
47struct Thread {
48 int id = 0;
49 std::thread cur_thread;
50 std::function<void()> task = nullptr;
51 bool running = true;
52 TaskQueue* t_queue = nullptr;
53
54 Thread(TaskQueue& task_queue_, int id_)
55 : id{ id_ }
56 , t_queue{ &task_queue_ }
57 {
58 cur_thread = std::thread([this]() {
59 run();
60 });
61 }
62
63 void run() {
64 while (running) {
65 std::function<void()> task;
66 if (t_queue->getTask(task)) {
67 task();
68 t_queue->completeTask();
69 }
70 }
71 }
72
73 void stop() {
74 running = false;
75 t_queue->shuttingDown = true;
76 t_queue->taskCondition.notify_all();
77 if (cur_thread.joinable()) {
78 cur_thread.join();
79 }
80 }
81
82};
83
84struct Threader {
85 TaskQueue t_queue;
86 int num_threads = 1;
87 std::vector<Thread> threads;
88
89 Threader(int num_threads_)
90 : num_threads{ num_threads_ }
91 {
92 threads.reserve(num_threads_);
93 for (int i = 0; i < num_threads_; i++)
94 threads.emplace_back(t_queue, i);
95 }
96
97 void parallel(int num_obj, std::function<void(int start, int end)>&& callback) {
98 if (num_obj == 0) return;
99 int slice_size = num_obj / num_threads;
100 for (int i = 0; i < num_threads; i++) {
101 int start = i * slice_size;
102 int end = start + slice_size;
103 t_queue.addTask([start, end, &callback]() { callback(start, end);});
104 }
105 if (slice_size * num_threads < num_obj) {
106 int start = slice_size * num_threads;
107 callback(start, num_obj);
108 }
109 //todo this may be done only at specific times
110 t_queue.waitUntilDone();
111 }
112};
Definition Threader.h:10
Definition Threader.h:47
Definition Threader.h:84