// ----------------------------------------------------------------------------
//
// Matrix multiplication (multithreaded client/server version).
// (c) Wolfgang Schreiner, 2005.
//
// ----------------------------------------------------------------------------
import java.io.*;
import java.net.*;
import java.nio.charset.StandardCharsets;

/**
 * Multithreaded matrix multiplication driven over a socket.
 *
 * <p>The server listens on {@link #port}. For every connection it reads one
 * line holding the number of worker threads, multiplies two freshly allocated
 * (all-zero) {@code N x N} matrices with that many threads and answers with
 * the line {@code Done}. Connections are served one after another.
 *
 * <p>The client reads lines from standard input, sends each line to the
 * server over a new connection and prints the one-line answer.
 */
class MatMultNet
{
  // name and port of server
  final static String URL = "localhost";
  final static int port = 9999;

  // matrix dimension used by the server
  final static int N = 1024;

  // number of threads
  static int T;

  // shared data used by threads; set by multiply() before workers are started
  static float[][] A0;
  static float[][] B0;
  static float[][] C0;

  /**
   * Worker that computes rows {@code begin .. end-1} of {@code C0 = A0 * B0}.
   */
  static class MultThread extends Thread
  {
    int begin; // index of first row
    int end;   // index of last row + 1

    // -------------------------------------------------------------------------
    //
    // Create thread in charge of filling C[begin..end-1]
    //
    // -------------------------------------------------------------------------
    public MultThread(int begin, int end)
    {
      this.begin = begin;
      this.end = end;
    }

    // -------------------------------------------------------------------------
    //
    // multiply the assigned rows of A0 by B0, storing them in C0
    //
    // -------------------------------------------------------------------------
    @Override
    public void run()
    {
      // Read the shared references once; they were published before start().
      final float[][] left = A0;
      final float[][] right = B0;
      final float[][] result = C0;
      // Dimensions come from the operands rather than from N, so the worker
      // is correct for any compatible shapes (N x N in the server).
      final int inner = right.length;
      for (int i = begin; i < end; i++) {
        final float[] leftRow = left[i];
        final float[] resultRow = result[i];
        for (int j = 0; j < resultRow.length; j++) {
          float m = 0;
          for (int k = 0; k < inner; k++) {
            m = m + leftRow[k] * right[k][j];
          }
          resultRow[j] = m;
        }
      }
    }
  }

  // ----------------------------------------------------------------------------
  //
  // multiply A by B giving C
  //
  // ----------------------------------------------------------------------------
  /**
   * Computes {@code C = A * B} using up to {@link #T} worker threads, each of
   * which fills a contiguous band of rows of {@code C}.
   *
   * <p>The shapes must be compatible: every row of {@code A} needs at least
   * {@code B.length} entries, every row of {@code B} at least as many entries
   * as the rows of {@code C}, and {@code C} at least {@code A.length} rows.
   * The operands are handed to the workers through the static fields
   * {@code A0}, {@code B0}, {@code C0}, so this method must not be called
   * concurrently.
   *
   * <p>If the calling thread is interrupted while waiting, the method still
   * waits for all workers (so {@code C} is complete on return) and then
   * restores the interrupt flag.
   *
   * @param A left operand
   * @param B right operand
   * @param C receives the product
   * @throws IllegalArgumentException if {@code T} is smaller than 1
   */
  static void multiply(float[][] A, float[][] B, float[][] C)
  {
    if (T < 1) {
      throw new IllegalArgumentException("number of threads must be positive: " + T);
    }
    A0 = A;
    B0 = B;
    C0 = C;

    final int rows = A.length;
    // More threads than rows would only create idle workers.
    final int workers = Math.min(T, rows);
    MultThread[] thread = new MultThread[workers];
    for (int i = 0; i < workers; i++) {
      // Even split: band sizes differ by at most one row. long avoids overflow.
      int begin = (int) ((long) i * rows / workers);
      int end = (int) ((long) (i + 1) * rows / workers);
      thread[i] = new MultThread(begin, end);
      thread[i].start();
    }

    boolean interrupted = false;
    for (int i = 0; i < workers; i++) {
      boolean joined = false;
      while (!joined) {
        try {
          thread[i].join();
          joined = true;
        } catch (InterruptedException e) {
          // Remember the request but keep waiting: workers still write to C.
          interrupted = true;
        }
      }
    }
    if (interrupted) {
      Thread.currentThread().interrupt();
    }
  }

  // ----------------------------------------------------------------------------
  //
  // main program
  //
  // ----------------------------------------------------------------------------
  /**
   * Starts the client for {@code -client} and the server for {@code -server};
   * prints a usage line for anything else.
   *
   * @param args exactly one argument, {@code -client} or {@code -server}
   */
  public static void main(String[] args)
  {
    if (args.length == 1 && args[0].equals("-client")) {
      client();
    } else if (args.length == 1 && args[0].equals("-server")) {
      server();
    } else {
      System.out.println("Usage: MatMultNet -client | -server");
    }
  }

  // ----------------------------------------------------------------------------
  //
  // client code
  //
  // ----------------------------------------------------------------------------
  /**
   * Sends every line read from standard input to the server, one connection
   * per line, and prints the server's one-line answer. Returns at end of
   * input; terminates the JVM with status -1 on an I/O error.
   */
  static void client()
  {
    try {
      BufferedReader console =
        new BufferedReader(new InputStreamReader(System.in));
      while (true) {
        String line = console.readLine();
        if (line == null) {
          return;
        }
        // Closing the socket also closes its streams.
        try (Socket socket = new Socket(URL, port)) {
          BufferedReader in = new BufferedReader(
            new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8));
          PrintWriter out = new PrintWriter(
            new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8));
          out.println(line);
          out.flush();
          String answer = in.readLine();
          if (answer == null) {
            System.out.println("empty result");
          } else {
            System.out.println("answer: " + answer);
          }
        }
      }
    } catch (IOException e) {
      System.out.println("Exception: " + e);
      System.exit(-1);
    }
  }

  // ----------------------------------------------------------------------------
  //
  // server code
  //
  // ----------------------------------------------------------------------------
  /**
   * Accepts connections on {@link #port} forever and serves them one at a
   * time. A failure while talking to one client is reported and the server
   * keeps running; a failure to listen or accept terminates the JVM with
   * status -1.
   */
  static void server()
  {
    try (ServerSocket server = new ServerSocket(port)) {
      while (true) {
        // accept() stays outside the inner try so that a broken listening
        // socket is fatal instead of spinning in the loop.
        Socket accepted = server.accept();
        try (Socket socket = accepted) {
          serve(socket);
        } catch (IOException e) {
          System.out.println("Exception: " + e);
        }
      }
    } catch (IOException e) {
      System.out.println("Exception: " + e);
      System.exit(-1);
    }
  }

  /**
   * Handles one request: reads the thread count, runs the multiplication and
   * answers {@code Done}, or answers {@code Invalid argument: ...} when the
   * request is missing, not an integer, or smaller than 1.
   */
  private static void serve(Socket socket) throws IOException
  {
    BufferedReader in = new BufferedReader(
      new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8));
    PrintWriter out = new PrintWriter(
      new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8));

    String request = in.readLine();
    String problem = null;
    int threads = 0;
    if (request == null) {
      problem = "missing thread count";
    } else {
      try {
        threads = Integer.parseInt(request.trim());
        if (threads < 1) {
          problem = "thread count must be positive: " + threads;
        }
      } catch (NumberFormatException e) {
        problem = e.getMessage();
      }
    }
    if (problem != null) {
      out.println("Invalid argument: " + problem);
      out.flush();
      return;
    }
    T = threads;

    System.out.println("Started");
    float[][] A = new float[N][N];
    float[][] B = new float[N][N];
    float[][] C = new float[N][N];
    multiply(A, B, C);
    System.out.println("Done");
    out.println("Done");
    out.flush();
  }
}