
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

/**
 * Multi-threaded JDBC load generator for a MySQL table.
 *
 * <p>For each thread count from 4 down to 1 the program starts that many worker threads. Each
 * worker opens its own connection (auto-commit disabled) and performs a fixed number of
 * single-row INSERTs or UPDATEs against table {@code dt_1}, committing after every statement. The
 * workload is chosen by the first command-line argument ({@code insert} or {@code update}). After
 * each round it prints the elapsed time, throughput, average per-operation latency and the
 * cumulative error count.
 *
 * <p>The table is addressed through columns {@code row_key} and {@code data}.
 */
public class CrashTest {

	/** Cumulative number of failed worker runs and connection failures; see {@link #updateErrors()}. */
	static int errorCounter;

	/** Next row key handed out by {@link #getNextKey()}; shared by all threads and all rounds. */
	static int global_key=0;

	// TODO: set this to your MySQL JDBC URL (host, database, user, password).
	static String connectString="jdbc:mysql://your.host.com:3306/test?user=user&password=password";

	/**
	 * Sum of the per-thread average operation latencies (ms) for the current round. Worker threads
	 * update it concurrently, so they must go through {@link #addTime(long)}.
	 */
	static long totalTime=0;

	/** Returns a process-wide unique, monotonically increasing key as a decimal string. */
	static synchronized String getNextKey() {
		return String.valueOf(global_key++);
	}

	/** Atomically adds {@code millis} to {@link #totalTime}. */
	static synchronized void addTime(long millis) {
		totalTime += millis;
	}

	/**
	 * Opens a new connection to {@link #connectString} with auto-commit disabled.
	 *
	 * @return the connection, or {@code null} if the driver could not be loaded or the connection
	 *     could not be opened or configured (the cause is printed to stdout)
	 */
	static Connection getConnection() {
		Connection con = null;
		try {
			// Loading the driver class registers it with DriverManager; this is required for
			// drivers that are not discovered automatically.
			Class.forName("com.mysql.jdbc.Driver");
			con = DriverManager.getConnection(connectString);
			con.setAutoCommit(false);
			return con;
		} catch (SQLException ex) {
			System.out.println("SQLException: " + ex.getMessage());
			System.out.println("SQLState: " + ex.getSQLState());
			System.out.println("VendorError: " + ex.getErrorCode());
		} catch (ClassNotFoundException e) {
			System.out.println("JDBC driver not found: " + e);
		}
		// setAutoCommit may fail after the connection was opened; don't leak it.
		closeQuietly(con);
		return null;
	}

	/**
	 * Inserts {@code iters} rows into {@code tableName}, committing after each row. It then adds
	 * the average per-row latency (ms) to {@link #totalTime}.
	 *
	 * <p>Keys come from {@link #getNextKey()} and the {@code data} column is set to {@code "1"}.
	 * The first failure aborts the remaining inserts. The failure is printed and counted through
	 * {@link #updateErrors()}.
	 *
	 * @param con open connection with auto-commit disabled
	 * @param iters number of rows to insert; no latency is recorded when {@code <= 0}
	 * @param tableName target table; it is concatenated into the SQL, so it must be trusted
	 */
	static void testSet(Connection con, int iters, String tableName) {
		PreparedStatement ps = null;
		long opTime = 0;
		try {
			ps = con.prepareStatement("INSERT INTO " + tableName + " (row_key,data) values (?,?)");
			int batchSize = 100;
			boolean useBatch = false; // flip to measure JDBC batching instead of per-row execution
			int pending = 0;
			for (int x = 0; x < iters; x++) {
				long startTime = System.currentTimeMillis();
				ps.setString(1, getNextKey());
				ps.setString(2, "1");

				if (useBatch) {
					ps.addBatch();
					if (++pending == batchSize) {
						ps.executeBatch();
						pending = 0;
					}
				} else {
					ps.executeUpdate();
				}
				con.commit();
				opTime += System.currentTimeMillis() - startTime;
			}
			if (pending > 0) {
				// Flush the trailing partial batch (not included in the timing).
				ps.executeBatch();
				con.commit();
			}
		} catch (Exception e) {
			System.out.println(e);
			updateErrors();
		} finally {
			closeQuietly(ps);
		}
		if (iters > 0) {
			addTime(opTime / iters);
		}
	}

	/**
	 * Updates the {@code data} column of {@code iters} rows in {@code tableName} with a 10 KiB
	 * string of {@code 'z'}, committing after each row. It then adds the average per-row latency
	 * (ms) to {@link #totalTime}.
	 *
	 * <p>Row keys come from {@link #getNextKey()}. An UPDATE that matches no row is not treated as
	 * an error. The first exception aborts the remaining updates. The exception is printed with
	 * the offending key and counted through {@link #updateErrors()}.
	 *
	 * @param con open connection with auto-commit disabled
	 * @param iters number of rows to update; no latency is recorded when {@code <= 0}
	 * @param tableName target table; it is concatenated into the SQL, so it must be trusted
	 */
	static void testUpdate(Connection con, int iters, String tableName) {
		// Built once, outside the timed loop, so it does not inflate the measured latency.
		String payload = repeat('z', 1024 * 10);
		PreparedStatement ps = null;
		long opTime = 0;
		String key = null;
		try {
			ps = con.prepareStatement("UPDATE " + tableName + " set data=? where row_key=?");
			int batchSize = 100;
			boolean useBatch = false; // flip to measure JDBC batching instead of per-row execution
			int pending = 0;
			for (int x = 0; x < iters; x++) {
				long startTime = System.currentTimeMillis();
				key = getNextKey();
				ps.setString(1, payload);
				ps.setString(2, key);

				if (useBatch) {
					ps.addBatch();
					if (++pending == batchSize) {
						ps.executeBatch();
						pending = 0;
					}
				} else {
					ps.executeUpdate();
				}
				con.commit();
				opTime += System.currentTimeMillis() - startTime;
			}
			if (pending > 0) {
				// Flush the trailing partial batch (not included in the timing).
				ps.executeBatch();
				con.commit();
			}
		} catch (Exception e) {
			System.out.println("testUpdate Error:" +e.getMessage() + " key= " + key);
			updateErrors();
		} finally {
			closeQuietly(ps);
		}
		if (iters > 0) {
			addTime(opTime / iters);
		}
	}

	/** Increments {@link #errorCounter}; safe to call from any worker thread. */
	static synchronized void updateErrors() {
		errorCounter++;
	}

	/**
	 * Body of one worker thread. It opens a private connection, runs the selected workload against
	 * {@code dt_1} and always closes the connection. A failed connection is counted as an error.
	 *
	 * @param mode {@code "insert"} or {@code "update"} (already validated by {@link #main})
	 * @param iterations number of operations to perform
	 */
	private static void runWorker(String mode, int iterations) {
		Connection con = getConnection();
		if (con == null) {
			updateErrors();
			return;
		}
		try {
			if (mode.equals("insert")) {
				testSet(con, iterations, "dt_1");
			} else {
				testUpdate(con, iterations, "dt_1");
			}
		} catch (Exception e) {
			System.out.println("t" + e);
			updateErrors();
		} finally {
			closeQuietly(con);
		}
	}

	/**
	 * Runs the benchmark for 4, 3, 2 and 1 threads, printing one result table per round.
	 *
	 * <p>Columns:
	 * <ul>
	 *   <li>time: elapsed whole seconds</li>
	 *   <li>tput(K): thousands of operations per hour</li>
	 *   <li>lat(ms): mean of the per-thread average latencies</li>
	 *   <li>errors: cumulative across rounds</li>
	 * </ul>
	 *
	 * @param args {@code args[0]} must be {@code insert} or {@code update}
	 */
	public static void main(String[] args) throws Exception {
		final String arg = args.length > 0 ? args[0] : "";
		// Validate once here instead of letting every worker thread race to exit.
		if (!arg.equals("insert") && !arg.equals("update")) {
			System.out.println("Please specify update or insert");
			System.exit(1);
		}

		final int perThreadIterations = 1000;
		for (int numThreads = 4; numThreads > 0; numThreads--) {
			Thread[] threads = new Thread[numThreads];
			totalTime = 0; // safe unsynchronized: no worker is running yet
			long startMs = System.currentTimeMillis();
			for (int i = 0; i < threads.length; i++) {
				threads[i] = new Thread(new Runnable() {
					@Override
					public void run() {
						runWorker(arg, perThreadIterations);
					}
				});
				threads[i].start();
			}

			// join() also makes the workers' writes to totalTime/errorCounter visible here.
			for (int i = 0; i < threads.length; i++) {
				threads[i].join();
			}

			long elapsedMs = System.currentTimeMillis() - startMs;
			long totalOps = (long) numThreads * perThreadIterations;
			long avgLatency = totalTime / numThreads;

			// ops/hour / 1000 == 3600 * totalOps / elapsedMs. Computing from milliseconds avoids
			// the large error caused by truncating short rounds to whole seconds.
			long throughput = elapsedMs > 0 ? 3600L * totalOps / elapsedMs : -1;

			System.out.println("time\tthreads\titers\ttotal\ttput(K)\tlat(ms)\terrors");
			System.out.println((elapsedMs / 1000) + "\t" + numThreads + "\t" + perThreadIterations + "\t"
					+ totalOps + "\t" + throughput + "\t" + avgLatency + "\t"
					+ errorCounter);
		}
	}

	/** Returns a string consisting of {@code count} copies of {@code c}. */
	private static String repeat(char c, int count) {
		StringBuilder sb = new StringBuilder(count);
		for (int i = 0; i < count; i++) {
			sb.append(c);
		}
		return sb.toString();
	}

	/** Closes {@code ps}, ignoring {@code null} and any error raised while closing. */
	private static void closeQuietly(PreparedStatement ps) {
		if (ps == null) {
			return;
		}
		try {
			ps.close();
		} catch (SQLException ignored) {
			// best-effort cleanup; the benchmark result is already recorded
		}
	}

	/** Closes {@code con}, ignoring {@code null} and any error raised while closing. */
	private static void closeQuietly(Connection con) {
		if (con == null) {
			return;
		}
		try {
			con.close();
		} catch (SQLException ignored) {
			// best-effort cleanup; the benchmark result is already recorded
		}
	}
}
