diff --git a/SubmissionSteps.ipynb b/SubmissionSteps.ipynb
index 1740a93..8e7cec8 100755
--- a/SubmissionSteps.ipynb
+++ b/SubmissionSteps.ipynb
@@ -1,428 +1,362 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"from scipy.stats import pearsonr\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"network_drive_dir = '/mnt/E132-Projekte/Projects/2019_n2c2_challenge/'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_scores_true = pd.read_csv(network_drive_dir + 'clinicalSTS2019.train.txt', delimiter='\\t', header=None)[2].to_numpy()\n",
"scores_true = pd.read_csv(network_drive_dir + 'output/submissions/8603970923442688/clinicalSTS2019.test.gs.sim.txt', header=None)[0].to_numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Enhanced Bert"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
- "name": "stdout",
"output_type": "stream",
+ "name": "stdout",
"text": [
- "0.8505066422830385 - Train scores step 2\n",
- "0.8586733365753405 - Test scores step 1\n"
+ "0.8505066422830385 - Train scores step 2\n0.8586733365753403 - Test scores step 1\n"
]
}
],
"source": [
"step1_train = pd.read_csv(network_drive_dir + 'output/run1/step1/bert_output_scores_train.csv')['score'].to_numpy()\n",
"step1_test = pd.read_csv(network_drive_dir + 'output/run1/step1/bert_output_scores_test.csv')['score'].to_numpy()\n",
"print(f'{pearsonr(step1_train, train_scores_true)[0]} - Train scores step 2')\n",
"print(f'{pearsonr(step1_test, scores_true)[0]} - Test scores step 1')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Voting Regression"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
- "name": "stdout",
"output_type": "stream",
+ "name": "stdout",
"text": [
- "0.8603970923442688 - Train scores step 2\n",
- "0.8491788063497707 - Test scores step 1\n"
+ "0.8603970923442686 - Train scores step 2\n0.8491788063497706 - Test scores step 1\n"
]
}
],
"source": [
"step2_train = pd.read_csv(network_drive_dir + 'output/run1/step2/some_features_0_dev_prediction.csv', header=None)[0].to_numpy()\n",
"step2_test = pd.read_csv(network_drive_dir + 'output/run1/step2/some_features_0_test_prediction.csv', header=None)[0].to_numpy()\n",
"print(f'{pearsonr(step2_train, train_scores_true)[0]} - Train scores step 2')\n",
"print(f'{pearsonr(step2_test, scores_true)[0]} - Test scores step 1')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Medication Graph"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
- "name": "stdout",
"output_type": "stream",
+ "name": "stdout",
"text": [
- "0.861858453761435 - Train scores step 2\n",
- "0.8616985288938903 - Test scores step 1\n",
- "same as in 2\n"
+ "0.8618584537614349 - Train scores step 2\n0.8616985288938904 - Test scores step 1\nsame as in 2\n"
]
}
],
"source": [
"step3_graph_train = pd.read_csv(network_drive_dir + 'output/run1/step3/preprocessed_data_2019-08-06_21-40-40/graph_scores_train.csv')['score'].to_numpy()\n",
"step3_graph_test = pd.read_csv(network_drive_dir + 'output/run1/step3/preprocessed_data_2019-08-06_21-40-40/graph_scores_test.csv')['score'].to_numpy()\n",
"step3_train = step2_train.copy()\n",
"step3_train[step3_graph_train != 10] = step3_graph_train[step3_graph_train != 10]\n",
"step3_test = step2_test.copy()\n",
"step3_test[step3_graph_test != 10] = step3_graph_test[step3_graph_test != 10]\n",
"print(f'{pearsonr(step3_train, train_scores_true)[0]} - Train scores step 2')\n",
"print(f'{pearsonr(step3_test, scores_true)[0]} - Test scores step 1')\n",
"print('same as in 2')"
]
},
{
+ "source": [
+ "sum(step3_graph_test != 10)"
+ ],
"cell_type": "code",
- "execution_count": 22,
"metadata": {},
+ "execution_count": 8,
"outputs": [
{
+ "output_type": "execute_result",
"data": {
"text/plain": [
- "array([10. , 10. , 10. , 1.7674538 , 10. ,\n",
- " 1.57582981, 10. , 10. , 1.69849115, 10. ,\n",
- " 10. , 1.51822006, 10. , 10. , 10. ,\n",
- " 10. , 10. , 1.67447326, 10. , 10. ,\n",
- " 10. , 10. , 1.54055718, 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 1.60288008, 10. , 1.6917981 , 10. , 10. ,\n",
- " 1.65231379, 10. , 4.13268305, 10. , 10. ,\n",
- " 10. , 10. , 10. , 1.64516726, 10. ,\n",
- " 10. , 10. , 1.67076833, 10. , 10. ,\n",
- " 1.51822055, 1.69984453, 1.56302873, 10. , 10. ,\n",
- " 10. , 10. , 1.62634715, 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 1.69013926, 10. , 10. ,\n",
- " 1.47596558, 10. , 10. , 1.7737868 , 10. ,\n",
- " 10. , 10. , 10. , 10. , 1.65886264,\n",
- " 10. , 1.72479863, 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 1.6338272 , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 1.69834204,\n",
- " 1.58688106, 10. , 10. , 10. , 1.69001543,\n",
- " 10. , 1.53229244, 10. , 10. , 10. ,\n",
- " 10. , 10. , 1.68304674, 1.57258378, 10. ,\n",
- " 10. , 4.34921153, 10. , 10. , 1.76747282,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 1.64439553,\n",
- " 1.70810151, 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 1.58855146, 10. ,\n",
- " 1.5868374 , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 1.55658649,\n",
- " 10. , 1.70055615, 1.68510345, 10. , 10. ,\n",
- " 10. , 1.55588571, 1.68000957, 10. , 10. ,\n",
- " 1.69580423, 10. , 1.63544031, 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 1.69947344, 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 1.6862104 , 10. , 10. , 10. , 10. ,\n",
- " 10. , 1.71018486, 10. , 10. , 10. ,\n",
- " 10. , 10. , 1.59721677, 1.60418829, 10. ,\n",
- " 1.70256379, 10. , 1.81674065, 10. , 10. ,\n",
- " 10. , 1.75757853, 10. , 1.69382919, 10. ,\n",
- " 10. , 10. , 10. , 1.61465125, 10. ,\n",
- " 10. , 1.62855809, 1.62757836, 10. , 1.59920505,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 1.60490655, 10. , 10. , 10. ,\n",
- " 10. , 1.60533072, 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 1.57447801, 1.70176241,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 1.72362936, 1.77422083, 10. ,\n",
- " 10. , 1.71163931, 10. , 1.69005578, 10. ,\n",
- " 10. , 10. , 10. , 1.78086753, 10. ,\n",
- " 10. , 10. , 10. , 1.66440614, 1.55073141,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 1.7167565 , 10. , 1.68716474, 10. , 10. ,\n",
- " 10. , 10. , 1.60996524, 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 1.74065919, 10. , 10. , 10. , 1.69126686,\n",
- " 10. , 10. , 10. , 2.15816834, 10. ,\n",
- " 10. , 1.63756216, 10. , 10. , 10. ,\n",
- " 1.62615826, 1.69949813, 10. , 10. , 10. ,\n",
- " 1.65159263, 10. , 10. , 10. , 1.6226005 ,\n",
- " 10. , 10. , 1.78153677, 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 1.62805305, 10. ,\n",
- " 10. , 1.69030582, 10. , 10. , 10. ,\n",
- " 10. , 10. , 1.70186892, 10. , 10. ,\n",
- " 1.69604549, 10. , 10. , 10. , 10. ,\n",
- " 1.59645883, 1.75049379, 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. , 1.66117857, 10. , 1.78725624,\n",
- " 10. , 10. , 1.67971476, 1.86105161, 10. ,\n",
- " 10. , 10. , 10. , 10. , 1.52834597,\n",
- " 10. , 1.57686315, 10. , 10. , 10. ,\n",
- " 10. , 4.30486055, 10. , 10. , 10. ,\n",
- " 10. , 1.6973243 , 1.50951602, 10. , 10. ,\n",
- " 10. , 1.72869246, 10. , 10. , 10. ,\n",
- " 10. , 10. , 10. , 10. , 1.5754832 ,\n",
- " 10. , 10. , 10. , 10. , 10. ,\n",
- " 10. , 10. ])"
+ "94"
]
},
- "execution_count": 22,
"metadata": {},
- "output_type": "execute_result"
+ "execution_count": 8
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "148"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 9
}
],
"source": [
- "step3_graph_test"
+ "sum(step3_graph_train != 10)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.91962365, 3.02080456, 4.32715539, 1.7674538 , 1.53140277,\n",
" 1.57582981, 2.69058631, 1.59874927, 1.69849115, 1.9051658 ,\n",
" 1.8822309 , 1.51822006, 2.15805445, 1.4193063 , 3.44182167,\n",
" 3.68950891, 4.43160836, 1.67447326, 4.63240929, 2.07942814,\n",
" 2.76862427, 1.85719852, 1.54055718, 1.74520421, 2.2335856 ,\n",
" 1.64414079, 2.40908164, 3.72529942, 3.05884908, 1.63124031,\n",
" 1.60288008, 4.76373226, 1.6917981 , 3.35776882, 2.93615075,\n",
" 1.65231379, 1.71684694, 4.13268305, 4.27047001, 3.84585319,\n",
" 4.08149692, 3.18624199, 2.73008227, 1.64516726, 2.14033141,\n",
" 2.39129219, 3.43961285, 1.67076833, 2.11208036, 2.12492612,\n",
" 1.51822055, 1.69984453, 1.56302873, 1.94514012, 1.31966434,\n",
" 3.18550331, 4.29035905, 1.62634715, 4.40107287, 1.44709537,\n",
" 3.36152128, 1.64484862, 2.43473888, 3.63454681, 3.71400064,\n",
" 4.38967013, 3.22783027, 1.69013926, 2.13835005, 1.70038675,\n",
" 1.47596558, 0.94433552, 1.97221421, 1.7737868 , 3.92457313,\n",
" 2.31992168, 3.91956857, 4.45567178, 2.20769868, 1.65886264,\n",
" 1.36184832, 1.72479863, 4.21610542, 1.85339048, 1.71856811,\n",
" 1.86391267, 2.92366817, 1.8488793 , 3.10642577, 2.48249743,\n",
" 2.4919538 , 4.30604713, 1.23096491, 2.45489335, 3.03711203,\n",
" 1.6338272 , 3.97904056, 3.77962891, 2.04450711, 3.38702596,\n",
" 3.78519973, 3.64293611, 2.18667532, 3.31840785, 1.69834204,\n",
" 1.58688106, 1.65787896, 2.41710945, 1.9228375 , 1.69001543,\n",
" 4.13973586, 1.53229244, 1.67481343, 3.91139301, 3.96181135,\n",
" 2.0334912 , 3.09225744, 1.68304674, 1.57258378, 3.05878476,\n",
" 3.99347257, 4.34921153, 3.08045675, 2.14264782, 1.76747282,\n",
" 1.58709927, 2.06912 , 4.069055 , 2.22283231, 4.07694842,\n",
" 2.43285668, 1.49072362, 2.35695482, 3.91205839, 3.47863266,\n",
" 1.65958811, 1.78430955, 2.38021665, 1.59949192, 1.64439553,\n",
" 1.70810151, 4.17765358, 4.17608642, 3.87234784, 4.38938527,\n",
" 1.31762992, 3.0725832 , 3.31467438, 1.58855146, 1.37983064,\n",
" 1.5868374 , 0.94013128, 1.3929758 , 4.07037806, 2.40823212,\n",
" 3.71738486, 1.76942317, 3.18742855, 1.91389529, 1.55658649,\n",
" 3.4659005 , 1.70055615, 1.68510345, 1.39378786, 2.1099217 ,\n",
" 1.89686412, 1.55588571, 1.68000957, 1.46804551, 1.5048956 ,\n",
" 1.69580423, 2.62279439, 1.63544031, 2.7335974 , 3.92044456,\n",
" 1.69068218, 1.61292242, 3.90594835, 4.31974304, 4.53325465,\n",
" 1.69947344, 1.29012865, 4.15169143, 1.7339509 , 2.3134333 ,\n",
" 4.22466654, 4.43978669, 3.49580178, 2.08742761, 2.11353594,\n",
" 1.6862104 , 1.08351088, 2.53304925, 1.6245541 , 1.91957371,\n",
" 4.12686945, 1.71018486, 2.35674727, 1.92881953, 1.42560422,\n",
" 2.22854256, 2.15584282, 1.59721677, 1.60418829, 3.65929524,\n",
" 1.70256379, 3.87029509, 1.81674065, 4.18622832, 2.04663929,\n",
" 3.75671027, 1.75757853, 1.92724393, 1.69382919, 1.33292578,\n",
" 3.45017525, 1.588931 , 3.81133735, 1.61465125, 1.23795452,\n",
" 3.50098649, 1.62855809, 1.62757836, 2.23902587, 1.59920505,\n",
" 1.64554549, 2.14390958, 1.62866843, 1.59751431, 2.49436411,\n",
" 2.13731004, 1.60490655, 2.36171287, 3.91157873, 1.77449964,\n",
" 4.2407261 , 1.60533072, 4.11296561, 1.34316987, 3.38236402,\n",
" 2.62700257, 3.71982487, 1.8085862 , 1.57447801, 1.70176241,\n",
" 3.71035709, 1.97845064, 3.08684286, 1.61046478, 1.75542087,\n",
" 2.29367552, 3.48853812, 4.24122098, 4.27165522, 3.49592816,\n",
" 2.17401834, 3.73973828, 1.72362936, 1.77422083, 2.16903226,\n",
" 3.62461272, 1.71163931, 4.19810068, 1.69005578, 3.12736414,\n",
" 2.9365779 , 3.71020531, 3.90583078, 1.78086753, 2.7620467 ,\n",
" 1.64607358, 2.10031083, 3.86475059, 1.66440614, 1.55073141,\n",
" 2.6050476 , 1.32454822, 2.74010879, 4.14889102, 3.83845493,\n",
" 1.7167565 , 4.15179908, 1.68716474, 2.23055751, 4.42799125,\n",
" 2.02344104, 2.0344298 , 1.60996524, 4.4121127 , 2.36971261,\n",
" 1.51439188, 3.28725506, 2.7466327 , 2.23565972, 1.42885279,\n",
" 3.37316893, 2.55641732, 2.48677688, 3.51434936, 2.71710757,\n",
" 1.74065919, 3.89423443, 1.84143851, 1.45153321, 1.69126686,\n",
" 1.20878499, 4.12768705, 3.93135253, 2.15816834, 4.25440236,\n",
" 3.64090361, 1.63756216, 4.14144861, 2.85060833, 3.9150575 ,\n",
" 1.62615826, 1.69949813, 1.62644883, 3.75231667, 4.36548079,\n",
" 1.65159263, 2.9337424 , 3.61779266, 1.23706488, 1.6226005 ,\n",
" 2.34443532, 3.6111164 , 1.78153677, 3.20187153, 2.30046119,\n",
" 3.82115995, 3.71672393, 2.51533753, 3.17775269, 2.15910817,\n",
" 3.90858714, 3.00489743, 2.28437026, 1.62805305, 4.28050426,\n",
" 3.58575154, 1.69030582, 2.21792333, 2.04359441, 2.2373275 ,\n",
" 2.34016778, 1.83259276, 1.70186892, 1.74213091, 3.02719001,\n",
" 1.69604549, 1.91987975, 2.77698019, 1.69801652, 3.72743867,\n",
" 1.59645883, 1.75049379, 4.00351329, 2.3714848 , 4.58482515,\n",
" 4.36790784, 4.13196113, 2.60022676, 3.14720635, 3.49595544,\n",
" 1.08128017, 3.5675986 , 1.66117857, 3.98749232, 1.78725624,\n",
" 2.21801375, 2.29492981, 1.67971476, 1.86105161, 1.71484288,\n",
" 4.46218263, 2.00993466, 2.05433472, 3.51673082, 1.52834597,\n",
" 1.59793487, 1.57686315, 3.04636759, 2.15094979, 3.55612742,\n",
" 4.37457148, 4.30486055, 2.57432097, 2.81452243, 3.94247364,\n",
" 1.43113796, 1.6973243 , 1.50951602, 4.46484213, 3.8469918 ,\n",
" 4.23737814, 1.72869246, 1.94759282, 4.4649204 , 2.90997844,\n",
" 1.40385751, 4.2836937 , 2.44694886, 4.01425475, 1.5754832 ,\n",
" 1.69643557, 4.00484059, 1.68132442, 4.54341418, 2.03674556,\n",
" 1.75813428, 1.84689662])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"step3_test"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.861858453761435 - Train scores step 2\n",
"same as in 2\n"
]
}
],
"source": [
"print(f'{pearsonr(step3_train, train_scores_true)[0]} - Train scores step 2')\n",
"#print(f'{pearsonr(step3_test, scores_true)[0]} - Test scores step 1')\n",
"print('same as in 2')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.8603970923442688 - Train scores step 2\n",
"0.8491788063497707 - Test scores step 1\n",
"same as in 2\n"
]
}
],
"source": [
"step3_train = pd.read_csv(network_drive_dir + 'output/run1/step3/preprocessed_data_2019-08-06_21-40-40/some_features_0_dev_prediction_step2.csv', header=None)[0].to_numpy()\n",
"step3_test = pd.read_csv(network_drive_dir + 'output/run1/step3/preprocessed_data_2019-08-06_21-40-40/some_features_0_test_prediction_step2.csv', header=None)[0].to_numpy()\n",
"print(f'{pearsonr(step3_train, train_scores_true)[0]} - Train scores step 2')\n",
"print(f'{pearsonr(step3_test, scores_true)[0]} - Test scores step 1')\n",
"print('same as in 2')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.8604672658183281 - Train scores step 2\n",
"0.8498357692426267 - Test scores step 1\n"
]
}
],
"source": [
"step4_train = pd.read_csv(network_drive_dir + 'output/run1/step4/some_features_0_dev_prediction.csv', header=None)[0].to_numpy()\n",
"step4_test = pd.read_csv(network_drive_dir + 'output/run1/step4/some_features_0_test_prediction.csv', header=None)[0].to_numpy()\n",
"print(f'{pearsonr(step4_train, train_scores_true)[0]} - Train scores step 2')\n",
"print(f'{pearsonr(step4_test, scores_true)[0]} - Test scores step 1')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.8"
+ "version": "3.8.5-final"
}
},
"nbformat": 4,
"nbformat_minor": 4
-}
+}
\ No newline at end of file
diff --git a/mtc/ScoreAnalysis.ipynb b/mtc/ScoreAnalysis.ipynb
new file mode 100755
index 0000000..bb24920
--- /dev/null
+++ b/mtc/ScoreAnalysis.ipynb
@@ -0,0 +1,457 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Test Evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "from scipy.stats import pearsonr\n",
+ "from pathlib import Path\n",
+ "pd.set_option(\"display.max_rows\", None, \"display.max_columns\", None, 'display.max_colwidth', None)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "path_data = Path(os.environ.get('NLP_RAW_DATA')) / 'n2c2'\n",
+ "path_results = Path(os.environ.get('NLP_EXPERIMENT_PATH')) / 'submission_generation' / '03_12_2020_20_18_37_original_data'\n",
+ "df_train = pd.read_csv(path_data / 'clinicalSTS2019.train.txt', delimiter='\\t', names=['sentence_a', 'sentence_b', 'score_true'])\n",
+ "df_train['score_step2'] = pd.read_csv(path_results / 'normal' / 'step2_train_scores.csv', header=None)[0].to_numpy()\n",
+ "df_train['score_step4'] = pd.read_csv(path_results / 'normal' / 'step4_train_scores.csv', header=None)[0].to_numpy()\n",
+ "df_train['score_diff'] = (abs(df_train['score_step4'] - df_train['score_true'])) - abs((df_train['score_step2'] - df_train['score_true']))\n",
+ "\n",
+ "df_train_med = df_train[np.abs(df_train['score_step2'] - df_train['score_step4']) > 0.001]\n",
+ "#df_train_med = df_train[df_train['score_step2'] != df_train['score_step4']]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(2.787484774665043, 1.388712509029767)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 51
+ }
+ ],
+ "source": [
+ "(df_train['score_true'].mean(), df_train['score_true'].std())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(147, 8.95249695493301)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 69
+ }
+ ],
+ "source": [
+ "(len(df_train_med), 100 * len(df_train_med) / len(df_train))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(2.0285714285714285, 1.048857826269141)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 53
+ }
+ ],
+ "source": [
+ "(df_train_med['score_true'].mean(), df_train_med['score_true'].std())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "0.6959206986788156"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 54
+ }
+ ],
+ "source": [
+ "np.sum(np.abs(df_train_med['score_step2'] - df_train_med['score_true'])**2)/len(df_train_med)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "0.5808822487618591"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 55
+ }
+ ],
+ "source": [
+ "np.sum(np.abs(df_train_med['score_step4'] - df_train_med['score_true'])**2)/len(df_train_med)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " sentence_a \\\n",
+ "1351 Prozac 20 mg capsule 1 capsule by mouth one time daily. \n",
+ "1289 ondansetron [ZOFRAN] 4 mg tablet 1 tablet by mouth three times a day as needed. \n",
+ "873 hydrochlorothiazide 25 mg tablet one-half tablet by mouth every morning. \n",
+ "479 Prozac 20 mg capsule 3 capsules by mouth one time daily. \n",
+ "1253 amlodipine [NORVASC] 5 mg tablet 2 tablets by mouth one time daily. \n",
+ "\n",
+ " sentence_b \\\n",
+ "1351 ibuprofen [ADVIL] 200 mg tablet 3 tablets by mouth one time daily as needed. \n",
+ "1289 amoxicillin [AMOXIL] 500 mg capsule 2 capsules by mouth three times a day. \n",
+ "873 ibuprofen [MOTRIN] 600 mg tablet 1 tablet by mouth four times a day. \n",
+ "479 Aleve 220 mg tablet 1 tablet by mouth two times a day. \n",
+ "1253 hydrochlorothiazide 12.5 mg tablet 1 tablet by mouth one time daily. \n",
+ "\n",
+ " score_true score_step2 score_step4 score_diff \n",
+ "1351 1.5 1.718230 1.699896 -0.018334 \n",
+ "1289 3.0 1.683182 1.699137 -0.015955 \n",
+ "873 1.5 1.590499 1.697998 0.107499 \n",
+ "479 0.5 2.019769 1.678388 -0.341381 \n",
+ "1253 1.5 1.880972 1.700071 -0.180902 "
+ ],
+ "text/html": "
\n\n
\n \n \n | \n sentence_a | \n sentence_b | \n score_true | \n score_step2 | \n score_step4 | \n score_diff | \n
\n \n \n \n 1351 | \n Prozac 20 mg capsule 1 capsule by mouth one time daily. | \n ibuprofen [ADVIL] 200 mg tablet 3 tablets by mouth one time daily as needed. | \n 1.5 | \n 1.718230 | \n 1.699896 | \n -0.018334 | \n
\n \n 1289 | \n ondansetron [ZOFRAN] 4 mg tablet 1 tablet by mouth three times a day as needed. | \n amoxicillin [AMOXIL] 500 mg capsule 2 capsules by mouth three times a day. | \n 3.0 | \n 1.683182 | \n 1.699137 | \n -0.015955 | \n
\n \n 873 | \n hydrochlorothiazide 25 mg tablet one-half tablet by mouth every morning. | \n ibuprofen [MOTRIN] 600 mg tablet 1 tablet by mouth four times a day. | \n 1.5 | \n 1.590499 | \n 1.697998 | \n 0.107499 | \n
\n \n 479 | \n Prozac 20 mg capsule 3 capsules by mouth one time daily. | \n Aleve 220 mg tablet 1 tablet by mouth two times a day. | \n 0.5 | \n 2.019769 | \n 1.678388 | \n -0.341381 | \n
\n \n 1253 | \n amlodipine [NORVASC] 5 mg tablet 2 tablets by mouth one time daily. | \n hydrochlorothiazide 12.5 mg tablet 1 tablet by mouth one time daily. | \n 1.5 | \n 1.880972 | \n 1.700071 | \n -0.180902 | \n
\n \n
\n
"
+ },
+ "metadata": {},
+ "execution_count": 56
+ }
+ ],
+ "source": [
+ "np.random.seed(5)\n",
+ "df_train_med.sample(5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_test = pd.read_csv(path_data / 'clinicalSTS2019.test.txt', delimiter='\\t', names=['sentence_a', 'sentence_b'])\n",
+ "df_test['score_step2'] = pd.read_csv(path_results / 'normal' / 'step2_test_scores.csv', header=None)[0].to_numpy()\n",
+ "df_test['score_step4'] = pd.read_csv(path_results / 'normal' / 'step4_test_scores.csv', header=None)[0].to_numpy()\n",
+ "df_test['score_true'] = pd.read_csv(path_data / 'clinicalSTS2019.test.gs.sim.txt', header=None)[0].to_numpy()\n",
+ "df_test['score_diff'] = (abs(df_test['score_step4'] - df_test['score_true'])) - abs((df_test['score_step2'] - df_test['score_true']))\n",
+ "\n",
+ "df_test_med = df_test[np.abs(df_test['score_step2'] - df_test['score_step4']) > 0.001]\n",
+ "#df_test_med = df_test[df_test['score_step2'] != df_test['score_step4']]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(1.7645631067961165, 1.5208707178893903)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 58
+ }
+ ],
+ "source": [
+ "(df_test['score_true'].mean(), df_test['score_true'].std())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(94, 22.815533980582526)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 70
+ }
+ ],
+ "source": [
+ "(len(df_test_med), 100 * len(df_test_med) / len(df_test))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "(1.0957446808510638, 0.5041579573299865)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 60
+ }
+ ],
+ "source": [
+ "(df_test_med['score_true'].mean(), df_test_med['score_true'].std())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "2.432894107188975"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 61
+ }
+ ],
+ "source": [
+ "np.sum(np.abs(df_test_med['score_step2'] - df_test_med['score_true'])**2)/len(df_test_med)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "0.5612688029859126"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 62
+ }
+ ],
+ "source": [
+ "np.sum(np.abs(df_test_med['score_step4'] - df_test_med['score_true'])**2)/len(df_test_med)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 77,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " sentence_a \\\n",
+ "30 Qsymia 3.75-23 mg capsule multiphasic release 24 hour 1 capsule by mouth one time daily. \n",
+ "205 Aleve 220 mg tablet 1 tablet by mouth two times a day. \n",
+ "117 lisinopril [PRINIVIL/ZESTRIL] 10 mg tablet 2 tablets by mouth one time daily. \n",
+ "338 Tylenol Extra Strength 500 mg tablet 1 tablet by mouth as directed by prescriber as needed. \n",
+ "121 ibuprofen [MOTRIN] 600 mg tablet 1 tablet by mouth every 6 hours as needed. \n",
+ "\n",
+ " sentence_b \\\n",
+ "30 Aleve 220 mg tablet 2 tablets by mouth one time daily as needed. \n",
+ "205 acetaminophen [TYLENOL] 500 mg tablet 2 tablets by mouth three times a day. \n",
+ "117 naproxen [NAPROSYN] 500 mg tablet 1 tablet by mouth two times a day. \n",
+ "338 furosemide [LASIX] 20 mg tablet 3 tablets by mouth one time daily. \n",
+ "121 ibuprofen [ADVIL] 200 mg tablet 2-3 tablets by mouth every 4 hours as needed. \n",
+ "\n",
+ " score_step2 score_step4 score_true score_diff \n",
+ "30 2.324836 1.661350 0.0 -0.663486 \n",
+ "205 2.736895 1.680362 1.5 -1.056533 \n",
+ "117 2.287069 1.691543 1.0 -0.595525 \n",
+ "338 1.877145 1.694849 1.0 -0.182296 \n",
+ "121 3.907524 4.261679 3.0 0.354155 "
+ ],
+ "text/html": "\n\n
\n \n \n | \n sentence_a | \n sentence_b | \n score_step2 | \n score_step4 | \n score_true | \n score_diff | \n
\n \n \n \n 30 | \n Qsymia 3.75-23 mg capsule multiphasic release 24 hour 1 capsule by mouth one time daily. | \n Aleve 220 mg tablet 2 tablets by mouth one time daily as needed. | \n 2.324836 | \n 1.661350 | \n 0.0 | \n -0.663486 | \n
\n \n 205 | \n Aleve 220 mg tablet 1 tablet by mouth two times a day. | \n acetaminophen [TYLENOL] 500 mg tablet 2 tablets by mouth three times a day. | \n 2.736895 | \n 1.680362 | \n 1.5 | \n -1.056533 | \n
\n \n 117 | \n lisinopril [PRINIVIL/ZESTRIL] 10 mg tablet 2 tablets by mouth one time daily. | \n naproxen [NAPROSYN] 500 mg tablet 1 tablet by mouth two times a day. | \n 2.287069 | \n 1.691543 | \n 1.0 | \n -0.595525 | \n
\n \n 338 | \n Tylenol Extra Strength 500 mg tablet 1 tablet by mouth as directed by prescriber as needed. | \n furosemide [LASIX] 20 mg tablet 3 tablets by mouth one time daily. | \n 1.877145 | \n 1.694849 | \n 1.0 | \n -0.182296 | \n
\n \n 121 | \n ibuprofen [MOTRIN] 600 mg tablet 1 tablet by mouth every 6 hours as needed. | \n ibuprofen [ADVIL] 200 mg tablet 2-3 tablets by mouth every 4 hours as needed. | \n 3.907524 | \n 4.261679 | \n 3.0 | \n 0.354155 | \n
\n \n
\n
"
+ },
+ "metadata": {},
+ "execution_count": 77
+ }
+ ],
+ "source": [
+ "np.random.seed(9)\n",
+ "df_test_med.sample(5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## M-Heads"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 64,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_trainh = pd.read_csv(path_data / 'clinicalSTS2019.train.txt', delimiter='\\t', names=['sentence_a', 'sentence_b', 'score_true'])\n",
+ "df_trainh['score_step2'] = pd.read_csv(path_results / 'heads' / 'step1_train_scores.csv', header=None)[0].to_numpy()\n",
+ "df_trainh['score_step4'] = pd.read_csv(path_results / 'heads' / 'step4_train_scores.csv', header=None)[0].to_numpy()\n",
+ "df_trainh['score_diff'] = (abs(df_trainh['score_step4'] - df_trainh['score_true'])) - abs((df_trainh['score_step2'] - df_trainh['score_true']))\n",
+ "\n",
+ "df_trainh_med = df_trainh[np.abs(df_trainh['score_step2'] - df_trainh['score_step4']) > 0.001]\n",
+ "#df_trainh_med = df_trainh[df_trainh['score_step2'] != df_trainh['score_step4']]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 65,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "146"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 65
+ }
+ ],
+ "source": [
+ "len(df_trainh_med)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "0.6101250129135304"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 66
+ }
+ ],
+ "source": [
+ "np.sum(np.abs(df_trainh_med['score_step2'] - df_trainh_med['score_true'])**2)/len(df_trainh_med)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 67,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "0.5969832368522427"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 67
+ }
+ ],
+ "source": [
+ "np.sum(np.abs(df_trainh_med['score_step4'] - df_trainh_med['score_true'])**2)/len(df_trainh_med)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5-final"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
\ No newline at end of file