a03ef0b32757085673185bdec436ac88f9a48188
[ovirt-viewer.git] / tunnel.c
1 /* ovirt viewer console application
2  * Copyright (C) 2008 Red Hat Inc.
3  * Written by Mohammed Morsi <mmorsi@redhat.com>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation; either version 2 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18  */
19
20 /* ovirt-viewer starts listening on network port to 
21  *  encapsulate vnc packets including the vm's name
22  *  so as to be able to be proxied. 
23  *
24  * This operation takes place in another thread
25  * which can be started/stopped by calling 
26  * start_tunnel / stop_tunnel. 
27  *
28  * An additional connection thread is created and maintained
29  * internally for each vm / vnc connection open in ovirt-viewer
30  * establishing a connection w/ the ovirt server.
31  */
32
33 #include <sys/types.h>
34 #include <sys/socket.h>
35 #include <stdlib.h>
36 #include <stdio.h>
37 #include <netinet/in.h>
38 #include <arpa/inet.h>
39 #include <unistd.h>
40 #include <assert.h>
41 #include <string.h>
42
43 #include <glib.h>
44
45 #include "internal.h"
46
47 /* constants */
48
49 // port to try to listen on, if we can't, increment until we find one we can
50 const int PORT_RANGE_START = 5600;
51
52 // max length of a vm name
53 const int VM_NAME_MAX_LEN = 250;
54
55 // max length of vnc data
56 const int VNC_DATA_MAX_LEN = 800000;
57
58 // port which to connect to on ovirt server
59 const int OVIRT_SERVER_PORT = 5900;
60
61 /* Private thread functions */
62 static gpointer tunnel_thread(gpointer data);
63 static gpointer client_server_thread(gpointer data);
64 static gpointer server_client_thread(gpointer data);
65
66 /* Other private functions */
67 static void close_socket(gpointer _socket, gpointer data);
68 static void wait_for_thread(gpointer _thread, gpointer data);
69
70 /* tunnel and main threads */
71 static GThread *tunnel_gthread = NULL;
72 static GThread *main_gthread = NULL;
73
74 /* list of communication threads */
75 static GSList *communication_threads = NULL;
76
77 /* list of sockets */
78 static GSList *sockets = NULL;
79
80 /* thread termination flag */
81 static gboolean run_tunnel = FALSE;
82
83 /*  internal.h shared constructs */
84 int tunnel_port;
85
86 /////////////////
87
88 /** public implementations **/
89
90 /* start tunnel thread */
91 void 
92 start_tunnel(void)
93 {
94   GError *error = NULL;
95
96   DEBUG ("starting the tunnel thread");
97
98   assert (tunnel_gthread == NULL);
99
100   run_tunnel = TRUE;
101
102   main_gthread = g_thread_self ();
103
104   tunnel_gthread = g_thread_create (tunnel_thread, NULL, TRUE, &error);
105   if (error) {
106     g_print ("%s\n", error->message);
107     g_error_free (error);
108     exit (1);
109   }
110 };
111
112 /* stop tunnel thread */
113 void 
114 stop_tunnel(void)
115 {
116   if(!run_tunnel)
117       return;
118
119   DEBUG ("stopping the tunnel thread");
120
121   assert (tunnel_gthread != NULL);
122   ASSERT_IS_MAIN_THREAD ();
123
124   run_tunnel = FALSE;
125
126   g_slist_foreach(sockets, close_socket, NULL);
127
128   (void) g_thread_join (tunnel_gthread);
129   tunnel_gthread = NULL;
130 };
131
132 /////////////////
133
134 /** private implementations **/
135
136 /* the tunnel thread */
137 static gpointer
138 tunnel_thread (gpointer _data)
139 {
140   //char vm_data[VM_NAME_MAX_LEN];
141   int local_server_socketfd, ovirt_server_socket, client_socketfd;
142   unsigned int local_server_len, client_len, ovirt_server_len;
143
144   struct sockaddr_in local_server_address;
145   struct sockaddr_in ovirt_server_address;
146   struct sockaddr_in client_address;
147
148   GThread *client_server_gthread = NULL;
149   GThread *server_client_gthread = NULL;
150
151   int sockets_param[2];
152   int * c_socket;
153
154   DEBUG ("tunnel thread starting up");
155
156   // ovirt server address
157   ovirt_server_address.sin_family = PF_INET;
158   ovirt_server_address.sin_addr.s_addr = inet_addr(hostname);
159   ovirt_server_address.sin_port = htons(OVIRT_SERVER_PORT);
160   ovirt_server_len = sizeof(ovirt_server_address);
161
162   // create local net socket
163   local_server_socketfd = socket(PF_INET, SOCK_STREAM, 0);
164   c_socket = malloc(sizeof(int)); *c_socket = local_server_socketfd;
165   sockets = g_slist_prepend(sockets, c_socket);
166
167   // local server address
168   tunnel_port = PORT_RANGE_START;
169   local_server_address.sin_family = PF_INET;
170   local_server_address.sin_addr.s_addr = inet_addr("127.0.0.1");
171   local_server_address.sin_port = htons(tunnel_port);
172   local_server_len = sizeof(local_server_address);
173
174   // increment ports until one is available
175   while(bind(local_server_socketfd, (struct sockaddr*)&local_server_address, local_server_len) < 0){
176      tunnel_port += 1;
177      local_server_address.sin_port += htons(tunnel_port);
178   }
179
180   DEBUG ("tunnel bound to local port %i", tunnel_port);
181
182   // increase client buffer size?
183   listen(local_server_socketfd, 5);
184
185   while(run_tunnel) {
186      // accept a client connection
187      DEBUG("tunnel accepting");
188      client_len = sizeof(client_address);
189      client_socketfd = accept(local_server_socketfd, (struct sockaddr*)&client_address, &client_len);
190      if(client_socketfd < 0){
191          DEBUG("tunnel accept failed");
192          break;
193      }
194      // TODO check accept return value for err
195      c_socket = malloc(sizeof(int)); *c_socket = client_socketfd;
196      sockets = g_slist_prepend(sockets, c_socket);
197
198      DEBUG ("client connected to tunnel");
199
200      // establish connection w/ ovirt server
201      ovirt_server_socket = socket(PF_INET, SOCK_STREAM, 0);
202      c_socket = malloc(sizeof(int)); *c_socket = ovirt_server_socket;
203      sockets = g_slist_prepend(sockets, c_socket);
204      DEBUG ("connecting to ovirt server %s on %i", hostname, OVIRT_SERVER_PORT);
205      if(connect(ovirt_server_socket, (struct sockaddr*)&ovirt_server_address, ovirt_server_len) < 0){
206        DEBUG ("could not connect to ovirt server");
207        break;
208        //return NULL;
209      }
210      DEBUG ("connected to ovirt server");
211
212      sockets_param[0]  = ovirt_server_socket; 
213      sockets_param[1]  = client_socketfd;
214
215      // launch thread for client -> server traffic
216      client_server_gthread = g_thread_create (client_server_thread, 
217                                               &sockets_param, TRUE, NULL);
218
219      // launch thread for server -> client traffic
220      server_client_gthread = g_thread_create (server_client_thread, 
221                                               &sockets_param, TRUE, NULL);
222
223      communication_threads = g_slist_prepend(communication_threads, client_server_gthread);
224      communication_threads = g_slist_prepend(communication_threads, server_client_gthread);
225
226      // send target vm for this session
227      //strcpy(vm_data, vm_in_focus->description);
228      DEBUG ("sending vm %s", vm_in_focus->description);
229      write(ovirt_server_socket, vm_in_focus->description, strlen(vm_in_focus->description));
230
231   }
232
233   DEBUG("terminating tunnel thread");
234
235   // wait for connection threads to finish
236   g_slist_foreach(communication_threads, wait_for_thread, NULL);
237
238   DEBUG ("tunnel thread completed");
239   return NULL;
240 };
241
242 /* the tunnel thread */
243 static gpointer
244 client_server_thread (gpointer _data){
245   int nbytes;
246   char vnc_data[VNC_DATA_MAX_LEN];
247
248   int ovirt_server_socket = ((int*)_data)[0],
249       client_socket = ((int*)_data)[1];
250
251   DEBUG ("client/server thread starting up");
252   
253   while(run_tunnel){
254     VERBOSE( "accepting client data");
255
256     // grab vnc data
257     nbytes = read(client_socket, vnc_data, VNC_DATA_MAX_LEN);
258     if(nbytes <= 0){
259       DEBUG ( "error reading data from client" );
260       break;
261     }
262     VERBOSE ("read %i bytes from client", nbytes);
263         
264     // send network_data onto server
265     nbytes = write(ovirt_server_socket, vnc_data, nbytes);
266     if(nbytes <= 0){
267       DEBUG ( "error writing data to server" );
268       break;
269     }
270     VERBOSE ("wrote %i bytes to server", nbytes);
271   }
272
273   DEBUG ("client/server thread completed");
274   return NULL;
275 };
276
277 /* the server thread */
278 static gpointer
279 server_client_thread (gpointer _data){
280   char vnc_data[VNC_DATA_MAX_LEN];
281
282   int ovirt_server_socket = ((int*)_data)[0],
283       client_socket = ((int*)_data)[1];
284
285   int nbytes;
286
287   DEBUG ("server/client thread starting up");
288   
289   while(run_tunnel){
290      // grab vnc data
291      nbytes = read(ovirt_server_socket, vnc_data, VNC_DATA_MAX_LEN);
292      if(nbytes <= 0){
293        DEBUG ( "error reading data from server" );
294        break;
295      }
296     VERBOSE ("read %i bytes from server", nbytes);
297
298     // send network_data onto client
299     nbytes = write(client_socket, vnc_data, nbytes);
300     if(nbytes <= 0){
301       DEBUG ( "error writing data to client" );
302       break;
303     }
304     VERBOSE ("wrote %i bytes to client", nbytes);
305   }
306
307   DEBUG ("server/client thread completed");
308   return NULL;
309 };
310
311 static void close_socket(gpointer _socket, gpointer data){
312   shutdown(*(int*) _socket, 2);
313   close(*(int*) _socket);
314   free((int*) _socket);
315 };
316
317 static void wait_for_thread(gpointer _thread, gpointer data){
318   g_thread_join((GThread*)_thread);
319 };