1 ///
2 module modbus.connection.tcp;
3 
4 import std.conv : to;
5 import std.datetime.stopwatch;
6 import std.exception : enforce;
7 import std.socket;
8 public import std.socket : Address, InternetAddress, Internet6Address;
9 version (Posix) public import std.socket : UnixAddress;
10 
11 import modbus.exception;
12 import modbus.connection.base;
13 import modbus.msleep;
14 
15 ///
/// Shared machinery for TCP Modbus connections: non-blocking `send`/`receive`
/// polling loops bounded by the connection's write/read timeouts.
abstract class TcpConnectionBase : AbstractConnection
{
protected:
    /// Underlying socket; stays `null` until a subclass assigns one.
    Socket sock;

    /// Optional replacement for `msleep`, used between polling steps.
    void delegate(Duration) sleepFunc;

    /// Wait for `d`, through `sleepFunc` when set and `msleep` otherwise.
    void sleep(Duration d)
    {
        if (sleepFunc !is null) sleepFunc(d);
        else msleep(d);
    }

    /// Pause between consecutive `send` attempts in `m_write`.
    Duration _writeStepSleep = 10.usecs;
    /// Pause between consecutive `receive` attempts in `m_read`.
    Duration _readStepSleep = 10.usecs;

    import core.stdc.errno;

    /++
        Send the whole of `buf` over `sock` in non-blocking mode.

        Partial and would-block sends are retried, pausing `_writeStepSleep`
        between attempts, until everything is sent or the write timeout
        (`_wtm`) elapses.

        Errors are reported through `throwCloseTcpConnection` (a `send` that
        reports 0 bytes), `throwModbusException` (socket errors other than
        would-block) and `throwTimeoutException` (timeout elapsed).
    +/
    void m_write(const(void)[] buf)
    {
        sock.blocking = false;
        // `send` of an empty slice returns 0, which the loop below would
        // misread as a closed connection; an empty message is trivially sent.
        if (buf.length == 0) return;

        size_t written;
        const sw = StopWatch(AutoStart.yes);
        while (sw.peek < _wtm)
        {
            auto res = sock.send(buf[written..$]);
            if (res == 0) // connection is closed
                throwCloseTcpConnection("write");
            if (res == Socket.ERROR)
            {
                // would-block just means "try again after a pause"
                if (wouldHaveBlocked) res = 0;
                else throwModbusException("TCP Socket send error " ~ sock.getErrorText);
            }
            written += res;
            if (written == buf.length) return;
            this.sleep(_writeStepSleep);
        }
        throwTimeoutException(sock.to!string, "write timeout");
    }

    /++
        Receive into `buf` in non-blocking mode, retrying every
        `_readStepSleep` until `buf` is full or the read timeout (`_rtm`)
        elapses.

        Returns: `buf` when it was filled completely; otherwise (after the
        timeout) the received prefix `buf[0 .. n]`.

        A timeout is reported through `throwTimeoutException` when `cr` is
        `CanRead.allOrNothing`, or `CanRead.anyNonZero` with nothing received.
        A `receive` of 0 bytes is reported through `throwCloseTcpConnection`;
        other socket errors (except would-block) through `throwModbusException`.
    +/
    void[] m_read(void[] buf, CanRead cr)
    {
        sock.blocking = false;
        // `receive` into an empty slice returns 0, which the loop below would
        // misread as a closed connection; an empty buffer is trivially full.
        if (buf.length == 0) return buf;

        size_t received;
        const sw = StopWatch(AutoStart.yes);
        while (sw.peek < _rtm)
        {
            auto res = sock.receive(buf[received..$]);
            if (res == 0) // connection is closed
                throwCloseTcpConnection("read");
            if (res == Socket.ERROR)
            {
                // would-block just means "nothing yet, poll again"
                if (wouldHaveBlocked) res = 0;
                else throwModbusException("TCP Socket receive error " ~ sock.getErrorText);
            }
            received += res;
            if (received == buf.length) return buf[];
            this.sleep(_readStepSleep);
        }
        if (cr == CanRead.allOrNothing || (cr == CanRead.anyNonZero && !received))
            throwTimeoutException(sock.to!string, "read timeout");
        return buf[0..received];
    }

public:

    /// The underlying socket (may be `null` before a connection exists).
    inout(Socket) socket() inout @property { return sock; }

    /// `false` when there is no socket, otherwise `Socket.isAlive`.
    bool isAlive() @property { return sock !is null && sock.isAlive; }

    /// Shut down both directions and close the socket; no-op without one.
    void close()
    {
        if (sock is null) return;
        sock.shutdown(SocketShutdown.BOTH);
        sock.close();
    }
}
97 
98 /// Client
/// Client
class MasterTcpConnection : TcpConnectionBase
{
protected:
    /// Remote endpoint the socket connects to.
    Address addr;

public:

    /++
        Params:
            addr = remote endpoint; the socket is created and connected lazily,
                   on the first `read`/`write` or on `reconnect`
            sleepFunc = optional replacement for `msleep` between polling steps
    +/
    this(Address addr, void delegate(Duration) sleepFunc=null)
    {
        this.addr = addr;
        this.sleepFunc = sleepFunc;
    }

    /// Create a TCP socket of `addr`'s family, connect it in blocking mode,
    /// then switch it to non-blocking mode.
    ///
    /// `sock` is assigned only after `connect` succeeded. On failure the
    /// fresh socket is closed and `sock` stays `null`, so the next
    /// `read`/`write` retries the connection instead of using a dead socket.
    protected void initSock()
    {
        auto s = new TcpSocket(addr.addressFamily);
        scope (failure) s.close();
        s.blocking = true;
        s.connect(addr);
        s.blocking = false;
        sock = s;
    }

override:

    /// Send all of `msg`, connecting first if needed.
    void write(const(void)[] msg)
    {
        if (sock is null) initSock();
        m_write(msg);
    }

    /// Receive into `buf` according to `cr`, connecting first if needed.
    void[] read(void[] buf, CanRead cr=CanRead.allOrNothing)
    {
        if (sock is null) initSock();
        return m_read(buf, cr);
    }

    /// Drop the current connection (if any) and connect again.
    void reconnect()
    {
        close();
        initSock();
    }

    /// Close the connection and forget the socket; no-op when not connected.
    void close()
    {
        if (sock is null) return;
        super.close();
        sock = null;
    }
}
148 
149 /// slave connection
/// Server-side (slave) connection wrapping an already existing socket.
class SlaveTcpConnection : TcpConnectionBase
{
    /// Wrap socket `s` (rejected when `null`) and switch it to non-blocking
    /// mode; `sf` optionally replaces `msleep` between polling steps.
    this(Socket s, void delegate(Duration) sf=null)
    {
        if (s is null)
            throwModbusException("TCP Socket is null");
        this.sock = s;
        this.sock.blocking = false;
        this.sleepFunc = sf;
    }

override:
    /// Send all of `msg` through the shared non-blocking write loop.
    void write(const(void)[] msg)
    {
        m_write(msg);
    }

    /// Receive into `buf` according to `cr` via the shared read loop.
    void[] read(void[] buf, CanRead cr=CanRead.allOrNothing)
    {
        return m_read(buf, cr);
    }

    /// Not supported: this side did not initiate the connection.
    void reconnect()
    {
        assert(0, "not allowed for SlaveTcpConnection");
    }
}
170 
171 version (unittest): package(modbus):
172 
173 import modbus.ut;
174 
/// Test helper: a fiber serving one accepted client. It appends incoming
/// bytes to `result` until at least `data.length` bytes arrived (or forever
/// when `inf` is set), then answers with its own `id`.
class CFCSlave : Fiber
{
    SlaveTcpConnection con;

    void[] result, data;
    size_t id;
    bool terminate, inf;

    this(Socket sock, size_t id, size_t dlen, bool inf=false)
    {
        this.id = id;
        this.inf = inf;
        con = new SlaveTcpConnection(sock);
        con.readTimeout = 1.seconds;
        data = new void[](dlen);
        super(&run);
    }

    void run()
    {
        testPrintf!"slave #%d start read"(id);
        for (;;)
        {
            const enough = result.length >= data.length;
            if (enough && !inf) break;
            auto chunk = con.read(data, con.CanRead.zero);
            result ~= chunk;
            testPrintf!"slave #%d readed %d"(id, result.length);
        }
        testPrintf!"slave #%d finish read (%d)"(id, result.length);

        const pause = uniform(1, 20).msecs;
        con.sleep(pause);
        con.write([id]);
        testPrintf!"slave #%d finish"(id);
    }
}
208 
/// Test helper: a fiber running a non-blocking listening socket. It spawns a
/// `CFCSlave` per accepted client, drives them, and finishes once every
/// spawned slave has terminated or had its connection closed.
class CFSlave : Fiber
{
    TcpSocket serv;
    CFCSlave[] cons;
    size_t dlen;
    SocketSet ss;
    bool inf;

    /++
        Params:
            addr = address to listen on
            cc = listen backlog
            dlen = bytes each slave expects to read
            inf = make slaves read until their connection is closed
    +/
    this(Address addr, int cc, size_t dlen, bool inf=false)
    {
        this.dlen = dlen;
        this.inf = inf;
        // Use addr's family (as MasterTcpConnection.initSock does) so IPv6
        // addresses can be bound, not only IPv4.
        serv = new TcpSocket(addr.addressFamily);
        serv.blocking = true;
        serv.bind(addr);
        serv.listen(cc);
        serv.blocking = false;
        ss = new SocketSet;
        super(&run);
    }

    void run()
    {
        while (true)
        {
            scope (exit) yield();
            ss.reset();
            ss.add(serv);

            // `select` returns -1 on error (e.g. EINTR), which is truthy;
            // only accept when it actually reports the listener as ready.
            while (Socket.select(ss, null, null, Duration.zero) > 0)
            {
                cons ~= new CFCSlave(serv.accept(), cons.length, dlen, inf);
                testPrintf!"new client, create slave #%d"(cons.length-1);
                yield();
            }

            foreach (c; cons.filter!(a=>a.state != a.State.TERM && !a.terminate))
            {
                try c.call;
                catch (CloseTcpConnection)
                {
                    testPrintf!"close slave #%d connection (%d bytes received)"(c.id, c.result.length);
                    c.con.close();
                    c.terminate = true;
                }
                c.con.sleep(uniform(1,5).msecs);
            }

            if (cons.length && cons.all!(a=>a.state == a.State.TERM || a.terminate))
            {
                testPrint("server finished");
                break;
            }
        }
    }
}
265 
/// Test helper: a client fiber. It sends `data` (random bytes) in random-sized
/// chunks, then either closes its socket (`noread`) or reads the slave's id.
class CFMaster : Fiber
{
    MasterTcpConnection con;
    size_t id;
    size_t serv_id;
    bool noread;

    void[] data;

    this(Address addr, size_t id, size_t dlen, bool noread=false)
    {
        this.id = id;
        this.noread = noread;
        con = new MasterTcpConnection(addr);
        data = new void[](dlen);
        foreach (ref v; cast(ubyte[])data)
            v = uniform(ubyte(0), ubyte(128));
        super(&run);
    }

    void run()
    {
        con.sleep(uniform(1, 50).msecs);
        size_t writted;
        while (writted != data.length)
        {
            auto cn = uniform(writted+1, data.length+1);
            con.write(data[writted..cn]);
            testPrintf!"master #%d writted %d"(id, cn);
            writted = cn;
            con.sleep(uniform(1, 10).msecs);
        }
        testPrintf!"master #%d send data"(id);
        con.readTimeout = 2000.msecs;
        con.sleep(uniform(1, 50).msecs);

        if (noread)
        {
            testPrintf!"close master #%d connection"(id);
            con.socket.shutdown(SocketShutdown.BOTH);
            con.socket.close();
        }
        else
        {
            // The slave answers with its id as one size_t. Read exactly that
            // many bytes: a shorter (partial) read cannot be reinterpreted
            // as size_t[] and would fail the cast/index below.
            void[size_t.sizeof] tmp = void;
            testPrintf!"master #%d start receive"(id);
            serv_id = (cast(size_t[])con.read(tmp[], con.CanRead.allOrNothing))[0];
            testPrintf!"master #%d receive serv id #%d"(id, serv_id);
        }
    }
}
317 
// Module tests: both fiber-based TCP scenarios, each run through `ut!` on its
// own loopback port. NOTE(review): ports 8091/8092 are hard-coded and must be
// free on the test host.
unittest
{
    mixin(mainTestMix);
    // N masters exchange data with one fiber server; payloads must match.
    ut!simpleFiberTest(new InternetAddress("127.0.0.1", 8091));
    // Masters close without reading; the server must notice every close.
    ut!closeSocketTest(new InternetAddress("127.0.0.1", 8092));
}
324 
/// Scenario: one `CFSlave` server and `N` `CFMaster` clients on `addr`, driven
/// round-robin until all fibers finish. Every server-side connection must have
/// received exactly the bytes its client sent (matched via the returned id).
void simpleFiberTest(Address addr)
{
    enum BS = 512;
    enum N = 12;
    auto cfs = new CFSlave(addr, N, BS);
    scope(exit) cfs.serv.close();
    CFMaster[] cfm;
    foreach (i; 0 .. N)
        cfm ~= new CFMaster(addr, i, BS);
    // `con.close()` is a no-op for a master whose socket was never created
    // (e.g. the test failed early), whereas `con.sock.close()` would crash.
    scope(exit) cfm.each!(a=>a.con.close());

    bool work = true;
    int step;
    while (work)
    {
        alias TERM = Fiber.State.TERM;
        if (cfs.state != TERM) cfs.call;
        foreach (c; cfm.filter!(a=>a.state != TERM)) c.call;

        step++;
        Thread.sleep(5.msecs);
        if (cfm.all!(a=>a.state == TERM) && cfs.state == TERM)
        {
            enforce(cfs.cons.length == N, "no server connections");
            foreach (i; 0 .. N)
            {
                auto mm = cfm[i];
                auto id = mm.serv_id;
                // serv_id was received over the socket; validate before indexing.
                enforce(id < cfs.cons.length, text("bad server id ", id, " for master #", i));
                auto ss = cfs.cons[id];
                if (ss.result.length == mm.data.length)
                    enforce(equal(cast(ubyte[])ss.result, cast(ubyte[])mm.data));
                else throw new Exception(text(ss.result, " != ", mm.data));
            }
            work = false;
            testPrintf!"basic loop steps: %s"(step);
        }
    }
}
365 
/// Scenario: every master sends `BS` bytes and then closes its socket without
/// reading, while the server's slaves read until the connection closes. At the
/// end every slave must be marked terminated, have a dead socket and hold
/// exactly `BS` received bytes.
void closeSocketTest(Address addr)
{
    enum BS = 512;
    enum N = 12;
    alias TERM = Fiber.State.TERM;

    auto server = new CFSlave(addr, N, BS, true);
    scope(exit) server.serv.close();

    auto masters = new CFMaster[](N);
    foreach (i, ref m; masters)
        m = new CFMaster(addr, i, BS, true);

    for (;;)
    {
        if (server.state != TERM) server.call;
        foreach (m; masters)
            if (m.state != TERM) m.call;

        Thread.sleep(5.msecs);

        const mastersDone = masters.all!(a => a.state == TERM);
        if (mastersDone && server.state == TERM)
        {
            assert(server.cons.all!(a=>a.terminate));
            assert(server.cons.all!(a=>!a.con.isAlive));
            assert(server.cons.all!(a=>a.result.length == BS));
            break;
        }
    }
}