From f70a10e6bd3821bbc5ac4d0fc161225106610029 Mon Sep 17 00:00:00 2001 From: Bradlee Speice Date: Wed, 30 Nov 2016 17:42:06 -0500 Subject: [PATCH] Add model definitions --- Default Prediction.ipynb | 529 +++++++++++++++++++++------------------ 1 file changed, 284 insertions(+), 245 deletions(-) diff --git a/Default Prediction.ipynb b/Default Prediction.ipynb index a8d176f..0a8e6c4 100644 --- a/Default Prediction.ipynb +++ b/Default Prediction.ipynb @@ -22,121 +22,21 @@ ] }, { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "root\n", - " |-- activity: string (nullable = true)\n", - " |-- basket_amount: long (nullable = true)\n", - " |-- bonus_credit_eligibility: boolean (nullable = true)\n", - " |-- borrowers: array (nullable = true)\n", - " | |-- element: struct (containsNull = true)\n", - " | | |-- first_name: string (nullable = true)\n", - " | | |-- gender: string (nullable = true)\n", - " | | |-- last_name: string (nullable = true)\n", - " | | |-- pictured: boolean (nullable = true)\n", - " |-- currency_exchange_loss_amount: double (nullable = true)\n", - " |-- delinquent: boolean (nullable = true)\n", - " |-- description: struct (nullable = true)\n", - " | |-- languages: array (nullable = true)\n", - " | | |-- element: string (containsNull = true)\n", - " | |-- texts: struct (nullable = true)\n", - " | | |-- ar: string (nullable = true)\n", - " | | |-- en: string (nullable = true)\n", - " | | |-- es: string (nullable = true)\n", - " | | |-- fr: string (nullable = true)\n", - " | | |-- id: string (nullable = true)\n", - " | | |-- mn: string (nullable = true)\n", - " | | |-- pt: string (nullable = true)\n", - " | | |-- ru: string (nullable = true)\n", - " | | |-- vi: string (nullable = true)\n", - " |-- funded_amount: long (nullable = true)\n", - " |-- funded_date: string (nullable = true)\n", - " |-- id: long (nullable = true)\n", - " |-- 
image: struct (nullable = true)\n", - " | |-- id: long (nullable = true)\n", - " | |-- template_id: long (nullable = true)\n", - " |-- journal_totals: struct (nullable = true)\n", - " | |-- bulkEntries: long (nullable = true)\n", - " | |-- entries: long (nullable = true)\n", - " |-- lender_count: long (nullable = true)\n", - " |-- loan_amount: long (nullable = true)\n", - " |-- location: struct (nullable = true)\n", - " | |-- country: string (nullable = true)\n", - " | |-- country_code: string (nullable = true)\n", - " | |-- geo: struct (nullable = true)\n", - " | | |-- level: string (nullable = true)\n", - " | | |-- pairs: string (nullable = true)\n", - " | | |-- type: string (nullable = true)\n", - " | |-- town: string (nullable = true)\n", - " |-- name: string (nullable = true)\n", - " |-- paid_amount: double (nullable = true)\n", - " |-- paid_date: string (nullable = true)\n", - " |-- partner_id: long (nullable = true)\n", - " |-- payments: array (nullable = true)\n", - " | |-- element: struct (containsNull = true)\n", - " | | |-- amount: double (nullable = true)\n", - " | | |-- currency_exchange_loss_amount: double (nullable = true)\n", - " | | |-- local_amount: double (nullable = true)\n", - " | | |-- payment_id: long (nullable = true)\n", - " | | |-- processed_date: string (nullable = true)\n", - " | | |-- rounded_local_amount: double (nullable = true)\n", - " | | |-- settlement_date: string (nullable = true)\n", - " |-- planned_expiration_date: string (nullable = true)\n", - " |-- posted_date: string (nullable = true)\n", - " |-- sector: string (nullable = true)\n", - " |-- status: string (nullable = true)\n", - " |-- tags: array (nullable = true)\n", - " | |-- element: struct (containsNull = true)\n", - " | | |-- name: string (nullable = true)\n", - " |-- terms: struct (nullable = true)\n", - " | |-- disbursal_amount: double (nullable = true)\n", - " | |-- disbursal_currency: string (nullable = true)\n", - " | |-- disbursal_date: string (nullable = 
true)\n", - " | |-- loan_amount: long (nullable = true)\n", - " | |-- local_payments: array (nullable = true)\n", - " | | |-- element: struct (containsNull = true)\n", - " | | | |-- amount: double (nullable = true)\n", - " | | | |-- due_date: string (nullable = true)\n", - " | |-- loss_liability: struct (nullable = true)\n", - " | | |-- currency_exchange: string (nullable = true)\n", - " | | |-- currency_exchange_coverage_rate: double (nullable = true)\n", - " | | |-- nonpayment: string (nullable = true)\n", - " | |-- repayment_interval: string (nullable = true)\n", - " | |-- repayment_term: long (nullable = true)\n", - " | |-- scheduled_payments: array (nullable = true)\n", - " | | |-- element: struct (containsNull = true)\n", - " | | | |-- amount: double (nullable = true)\n", - " | | | |-- due_date: string (nullable = true)\n", - " |-- themes: array (nullable = true)\n", - " | |-- element: string (containsNull = true)\n", - " |-- translator: struct (nullable = true)\n", - " | |-- byline: string (nullable = true)\n", - " | |-- image: long (nullable = true)\n", - " |-- use: string (nullable = true)\n", - " |-- video: struct (nullable = true)\n", - " | |-- id: long (nullable = true)\n", - " | |-- thumbnailImageId: long (nullable = true)\n", - " | |-- title: string (nullable = true)\n", - " | |-- youtubeId: string (nullable = true)\n", - "\n" - ] - } - ], + "cell_type": "markdown", + "metadata": {}, "source": [ - "loans.printSchema()" + "# Custom Functions\n", + "\n", + "## Gender Ratio\n", + "\n", + "0 = All female\n", + "\n", + "1 = All male" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 2, "metadata": { "collapsed": false }, @@ -144,7 +44,7 @@ "source": [ "import pyspark\n", "\n", - "def male_proportion(array):\n", + "def gender_ratio(array):\n", " num_males = 0\n", " for item in array:\n", " if item.gender == 'M':\n", @@ -152,55 +52,9 @@ " \n", " return float(num_males) / len(array)\n", "\n", - 
"sparkSql.udf.register('male_proportion',\n", - " male_proportion,\n", - " pyspark.sql.types.FloatType())\n", - "\n", - "train, validation, test = loans.randomSplit([.6, .2, .2], 101)\n", - "\n", - "query = '''\n", - "SELECT\n", - " id,\n", - " activity,\n", - " size(borrowers) as num_borrowers,\n", - " male_proportion(borrowers) as male_proportion,\n", - " lender_count,\n", - " location.country,\n", - " location.country_code,\n", - " partner_id,\n", - " sector,\n", - " tags,\n", - " DATEDIFF(terms.disbursal_date, planned_expiration_date) as loan_length,\n", - " terms.disbursal_amount,\n", - " terms.disbursal_currency,\n", - " terms.disbursal_date,\n", - " size(terms.scheduled_payments) as num_repayments,\n", - " terms.repayment_interval,\n", - " CASE WHEN\n", - " (status = 'refunded') OR\n", - " (status = 'defaulted') OR\n", - " (status = 'deleted') OR\n", - " (status = 'issue') OR\n", - " (status = 'inactive_expired') OR\n", - " (status = 'expired') OR\n", - " (status = 'inactive') OR\n", - " (delinquent = True) THEN 1 ELSE 0 END AS bad_loan,\n", - " gdp(location.country_code, terms.disbursal_date) as gdp,\n", - " xchange_rate(location.country_code, terms.disbursal_date) as xchange_rate,\n", - " status,\n", - " delinquent\n", - " \n", - "FROM {}\n", - "WHERE\n", - " status != 'fundraising' AND\n", - " status != 'funded'\n", - "'''\n", - "\n", - "train.registerTempTable('loans_train')\n", - "validation.registerTempTable('loans_validation')\n", - "test.registerTempTable('loans_test')\n", - "\n", - "sparkSql.sql(query.format('loans_validation')).write.json('validation_data-filtered.json')" + "sparkSql.udf.register('gender_ratio',\n", + " gender_ratio,\n", + " pyspark.sql.types.FloatType())" ] }, { @@ -214,7 +68,7 @@ "cell_type": "code", "execution_count": 3, "metadata": { - "collapsed": true + "collapsed": false }, "outputs": [], "source": [ @@ -225,17 +79,8 @@ "\n", "# Load country info data\n", "country_codes_raw = 
pd.read_csv('economic-data/country-codes.csv')\n", - "country_gdp_raw = pd.read_csv('economic-data/country-gdp.csv')" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ + "country_gdp_raw = pd.read_csv('economic-data/country-gdp.csv')\n", + "\n", "# Clean country codes data\n", "country_codes = country_codes_raw[['official_name_en', 'ISO3166-1-Alpha-2', \n", " 'ISO3166-1-Alpha-3', 'ISO4217-currency_alphabetic_code']]\n", @@ -243,18 +88,9 @@ "# Clean gdp data\n", "country_gdp = country_gdp_raw.drop(country_gdp_raw.columns[[0, 1]], axis=1)\n", "country_gdp.columns = ['name', 'country_code_3', '2002', '2003', '2004', '2005', '2006',\n", - " '2007', '2008', '2009', '2010', '2011', '2012', '2013', '2014', '2015', '2016']" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "# Merde gdp and code\n", + " '2007', '2008', '2009', '2010', '2011', '2012', '2013', '2014', '2015', '2016']\n", + "\n", + "# Merge gdp and code\n", "country_gdp = pd.merge(country_gdp, country_codes, left_on='country_code_3', right_on='ISO3166-1-Alpha-3', how='left')\n", "country_gdp.drop(['official_name_en', 'ISO3166-1-Alpha-3', 'country_code_3'], axis=1, inplace=True)\n", "country_gdp = country_gdp.rename(columns = {'ISO3166-1-Alpha-2':'country_code',\n", @@ -265,17 +101,8 @@ "cols = list(country_gdp.columns)\n", "cols.insert(1, cols.pop(cols.index('country_code')))\n", "cols.insert(2, cols.pop(cols.index('currency_code')))\n", - "country_gdp = country_gdp.reindex(columns= cols)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ + "country_gdp = country_gdp.reindex(columns= cols)\n", + "\n", "def gdp(country_code, disbursal_date):\n", " def historical_gdp(array):\n", " array = np.array(map(float, array))\n", @@ -316,37 +143,18 @@ }, { "cell_type": "code", - 
"execution_count": 7, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "currencies_raw = pd.read_csv('economic-data/currencies.csv')" - ] - }, - { - "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ + "currencies_raw = pd.read_csv('economic-data/currencies.csv')\n", "# Cleanup\n", "currencies = currencies_raw.drop(country_gdp_raw.columns[[0, 1]], axis=1)\n", "currencies.columns = ['country_name', 'country_code_3', '2002', '2003', '2004', '2005', '2006',\n", - " '2007', '2008', '2009', '2010', '2011', '2012', '2013', '2014', '2015', '2016']" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ + " '2007', '2008', '2009', '2010', '2011', '2012', '2013', '2014', '2015', '2016']\n", + "\n", "# Get ISO 2 code\n", "currencies = pd.merge(currencies, country_codes, left_on='country_code_3', right_on='ISO3166-1-Alpha-3', how='left')\n", "currencies.drop(['official_name_en', 'ISO3166-1-Alpha-3', 'country_code_3'], axis=1, inplace=True)\n", @@ -362,17 +170,8 @@ "cols = list(currencies.columns)\n", "cols.insert(1, cols.pop(cols.index('country_code')))\n", "cols.insert(2, cols.pop(cols.index('currency_code')))\n", - "currencies = currencies.reindex(columns=cols)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ + "currencies = currencies.reindex(columns=cols)\n", + "\n", "def xchange_rate(country_code, disbursal_date):\n", " def historical_rates(array):\n", " array = np.array(map(float, array))\n", @@ -420,9 +219,161 @@ "sparkSql.udf.register('xchange_rate', xchange_rate, pyspark.sql.types.FloatType())" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fetch actual data\n", + "\n", + "Get all data that we are going to use, get dummies, then split into train/validation/test." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Query our datasets to train on." + ] + }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 6, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "query = '''\n", + "SELECT\n", + " id,\n", + " activity,\n", + " size(borrowers) as num_borrowers,\n", + " gender_ratio(borrowers) as gender_ratio,\n", + " lender_count,\n", + " location.country,\n", + " location.country_code,\n", + " partner_id,\n", + " sector,\n", + " tags,\n", + " DATEDIFF(terms.disbursal_date, planned_expiration_date) as loan_length,\n", + " terms.disbursal_amount,\n", + " terms.disbursal_currency,\n", + " terms.disbursal_date,\n", + " size(terms.scheduled_payments) as num_repayments,\n", + " terms.repayment_interval,\n", + " CASE WHEN\n", + " (status = 'refunded') OR\n", + " (status = 'defaulted') OR\n", + " (status = 'deleted') OR\n", + " (status = 'issue') OR\n", + " (status = 'inactive_expired') OR\n", + " (status = 'expired') OR\n", + " (status = 'inactive') OR\n", + " (delinquent = True) THEN 1 ELSE 0 END AS bad_loan,\n", + " gdp(location.country_code, terms.disbursal_date) as gdp,\n", + " xchange_rate(location.country_code, terms.disbursal_date) as xchange_rate,\n", + " status,\n", + " delinquent\n", + " \n", + "FROM loans\n", + "WHERE\n", + " status != 'fundraising' AND\n", + " status != 'funded'\n", + "'''\n", + "\n", + "dataset = sparkSql.sql(query).toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data Splits" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "X_columns = [\n", + " 'activity', 'num_borrowers', 'gender_ratio',\n", + " 'lender_count', 'country', 'partner_id', 'sector',\n", + " 'loan_length', 'disbursal_amount', 'disbursal_currency',\n", + " 'num_repayments', 'repayment_interval', 'gdp', 'xchange_rate'\n", + "]\n", + "\n", + "y_column = 
['bad_loan']\n", + "\n", + "dummy_set = pd.get_dummies(dataset[X_columns + y_column])\n", + "dummy_set.to_csv('processed_dummy.csv')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can restart the kernel to clear memory, and start processing." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "processed_dummy = pd.read_csv('processed_dummy.csv', index_col=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "train, validate, test = np.split(processed_dummy.sample(frac=1, random_state=0),\n", + " [int(.6*len(processed_dummy)),\n", + " int(.8*len(processed_dummy))])\n", + "\n", + "train.to_csv('processed_train.csv')\n", + "validate.to_csv('processed_validate.csv')\n", + "test.to_csv('processed_test.csv')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Testing all the models" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "train = pd.read_csv('processed_train.csv', index_col=0).dropna(axis=1)\n", + "valid = pd.read_csv('processed_validate.csv', index_col=0).dropna(axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, "metadata": { "collapsed": false }, @@ -430,26 +381,82 @@ { "data": { "text/plain": [ - "[Row(id=507280, activity=u'Agriculture', num_borrowers=10, male_proportion=0.10000000149011612, lender_count=91, country=u'Rwanda', country_code=u'RW', partner_id=170, sector=u'Agriculture', tags=[], loan_length=-59, disbursal_amount=1500000.0, disbursal_currency=u'RWF', disbursal_date=u'2012-11-15T08:00:00Z', num_repayments=1, repayment_interval=u'At end of term', bad_loan=0, gdp=667.4146118164062, xchange_rate=614.295166015625, 
status=u'paid', delinquent=None),\n", - " Row(id=508860, activity=u'Agriculture', num_borrowers=1, male_proportion=1.0, lender_count=28, country=u'Rwanda', country_code=u'RW', partner_id=170, sector=u'Agriculture', tags=[], loan_length=-52, disbursal_amount=500000.0, disbursal_currency=u'RWF', disbursal_date=u'2012-11-26T08:00:00Z', num_repayments=1, repayment_interval=u'At end of term', bad_loan=0, gdp=667.4146118164062, xchange_rate=614.295166015625, status=u'paid', delinquent=None),\n", - " Row(id=498729, activity=u'Agriculture', num_borrowers=1, male_proportion=0.0, lender_count=6, country=u'Kenya', country_code=u'KE', partner_id=133, sector=u'Agriculture', tags=[], loan_length=-38, disbursal_amount=20000.0, disbursal_currency=u'KES', disbursal_date=u'2012-11-13T08:00:00Z', num_repayments=12, repayment_interval=u'Monthly', bad_loan=0, gdp=1184.9232177734375, xchange_rate=84.52960205078125, status=u'paid', delinquent=None),\n", - " Row(id=501877, activity=u'Agriculture', num_borrowers=1, male_proportion=1.0, lender_count=14, country=u'Peru', country_code=u'PE', partner_id=71, sector=u'Agriculture', tags=[], loan_length=-39, disbursal_amount=1000.0, disbursal_currency=u'PEN', disbursal_date=u'2012-11-20T08:00:00Z', num_repayments=8, repayment_interval=u'Monthly', bad_loan=0, gdp=6389.63037109375, xchange_rate=2.6375863552093506, status=u'paid', delinquent=None),\n", - " Row(id=504386, activity=u'Agriculture', num_borrowers=1, male_proportion=1.0, lender_count=16, country=u'Benin', country_code=u'BJ', partner_id=104, sector=u'Agriculture', tags=[], loan_length=-58, disbursal_amount=190000.0, disbursal_currency=u'XOF', disbursal_date=u'2012-11-08T08:00:00Z', num_repayments=4, repayment_interval=u'Irregularly', bad_loan=0, gdp=807.6884765625, xchange_rate=510.5271301269531, status=u'paid', delinquent=None),\n", - " Row(id=510144, activity=u'Agriculture', num_borrowers=1, male_proportion=1.0, lender_count=7, country=u'Senegal', country_code=u'SN', partner_id=108, 
sector=u'Agriculture', tags=[], loan_length=-53, disbursal_amount=150000.0, disbursal_currency=u'XOF', disbursal_date=u'2012-11-27T08:00:00Z', num_repayments=12, repayment_interval=u'Monthly', bad_loan=0, gdp=1019.272216796875, xchange_rate=510.5271301269531, status=u'paid', delinquent=None),\n", - " Row(id=497262, activity=u'Agriculture', num_borrowers=1, male_proportion=0.0, lender_count=11, country=u'Nicaragua', country_code=u'NI', partner_id=74, sector=u'Agriculture', tags=[], loan_length=-35, disbursal_amount=7000.0, disbursal_currency=u'NIO', disbursal_date=u'2012-11-14T08:00:00Z', num_repayments=1, repayment_interval=u'At end of term', bad_loan=0, gdp=1776.209228515625, xchange_rate=23.546663284301758, status=u'paid', delinquent=None),\n", - " Row(id=503327, activity=u'Agriculture', num_borrowers=1, male_proportion=0.0, lender_count=7, country=u'Mexico', country_code=u'MX', partner_id=224, sector=u'Agriculture', tags=[], loan_length=-7, disbursal_amount=3000.0, disbursal_currency=u'MXN', disbursal_date=u'2012-12-28T08:00:00Z', num_repayments=1, repayment_interval=u'At end of term', bad_loan=0, gdp=9720.5615234375, xchange_rate=13.169458389282227, status=u'paid', delinquent=None),\n", - " Row(id=500119, activity=u'Agriculture', num_borrowers=1, male_proportion=0.0, lender_count=30, country=u'Mexico', country_code=u'MX', partner_id=224, sector=u'Agriculture', tags=[], loan_length=6, disbursal_amount=12000.0, disbursal_currency=u'MXN', disbursal_date=u'2012-12-28T08:00:00Z', num_repayments=1, repayment_interval=u'At end of term', bad_loan=0, gdp=9720.5615234375, xchange_rate=13.169458389282227, status=u'paid', delinquent=None),\n", - " Row(id=153403, activity=u'Agriculture', num_borrowers=1, male_proportion=0.0, lender_count=37, country=u'Togo', country_code=u'TG', partner_id=22, sector=u'Agriculture', tags=[], loan_length=None, disbursal_amount=450000.0, disbursal_currency=u'XOF', disbursal_date=u'2009-10-26T07:00:00Z', num_repayments=14, 
repayment_interval=u'Irregularly', bad_loan=1, gdp=508.54052734375, xchange_rate=472.186279296875, status=u'defaulted', delinquent=True)]" + "342" ] }, - "execution_count": 13, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# sparkSql.sql(query.format('loans_validation')).take(10)\n", - "sparkSql.sql(query.format('loans_validation')).write.json('validation_data-filtered.json')" + "len(train.columns)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Naive guess:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.89836166750827584" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_x = train.drop('bad_loan', axis=1)\n", + "train_y = train['bad_loan']\n", + "valid_x = valid.drop('bad_loan', axis=1)\n", + "valid_y = valid['bad_loan']\n", + "\n", + "1 - train_y.mean()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SVM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from itertools import product\n", + "import pickle\n", + "from sklearn.svm import SVC\n", + "\n", + "svc_params = product([1, .5, 1.5], [.001, .01, .1])\n", + "\n", + "for C, gamma in svc_params:\n", + " svc = SVC(C=C, gamma=gamma)\n", + "\n", + " svc.fit(train_x, train_y)\n", + " with open('svc_{}_{}.pickle'.format(C, gamma), 'wb') as handle: # pickle needs a binary handle\n", + " pickle.dump(svc, handle)\n", + " \n", + " print(\"C: {}; gamma: {}; score: {}\".format(\n", + " C, gamma, svc.score(valid_x, valid_y)))" ] }, { @@ -459,7 +466,39 @@ "collapsed": true }, "outputs": [], - "source": [] + "source": [ + "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", + "\n", + "# Number of columns is 342. NOTE(review): LDA caps n_components at n_classes - 1; confirm these values differ in effect\n", + "for n_components in [342, 250, 150, 75]:\n", + " lda = 
LinearDiscriminantAnalysis(n_components=n_components)\n", + " lda.fit(train_x, train_y)\n", + " with open('lda_{}.pickle'.format(n_components), 'wb') as handle:\n", + " pickle.dump(lda, handle)\n", + " \n", + " print(\"N_components: {}; score: {}\".format(\n", + " n_components, lda.score(valid_x, valid_y)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestClassifier\n", + "\n", + "for n_estimators in [10, 50, 75, 100]:\n", + " rf = RandomForestClassifier(n_estimators=n_estimators, random_state=0)\n", + " rf.fit(train_x, train_y)\n", + " with open('rf_{}.pickle'.format(n_estimators), 'wb') as handle:\n", + " pickle.dump(rf, handle)\n", + " \n", + " print(\"N_estimators: {}; score: {}\".format(\n", + " n_estimators, rf.score(valid_x, valid_y)))" + ] } ], "metadata": { @@ -478,7 +517,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", - "version": "2.7.10" + "version": "2.7.12" } }, "nbformat": 4,