diff --git a/SubmissionSteps.ipynb b/SubmissionSteps.ipynb index 1740a93..8e7cec8 100755 --- a/SubmissionSteps.ipynb +++ b/SubmissionSteps.ipynb @@ -1,428 +1,362 @@ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Test Evaluation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "import numpy as np\n", "from scipy.stats import pearsonr\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "network_drive_dir = '/mnt/E132-Projekte/Projects/2019_n2c2_challenge/'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "train_scores_true = pd.read_csv(network_drive_dir + 'clinicalSTS2019.train.txt', delimiter='\\t', header=None)[2].to_numpy()\n", "scores_true = pd.read_csv(network_drive_dir + 'output/submissions/8603970923442688/clinicalSTS2019.test.gs.sim.txt', header=None)[0].to_numpy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Enhanced Bert" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ - "0.8505066422830385 - Train scores step 2\n", - "0.8586733365753405 - Test scores step 1\n" + "0.8505066422830385 - Train scores step 2\n0.8586733365753403 - Test scores step 1\n" ] } ], "source": [ "step1_train = pd.read_csv(network_drive_dir + 'output/run1/step1/bert_output_scores_train.csv')['score'].to_numpy()\n", "step1_test = pd.read_csv(network_drive_dir + 'output/run1/step1/bert_output_scores_test.csv')['score'].to_numpy()\n", "print(f'{pearsonr(step1_train, train_scores_true)[0]} - Train scores step 2')\n", "print(f'{pearsonr(step1_test, scores_true)[0]} - Test scores step 1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Voting Regression" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ - "0.8603970923442688 - Train scores step 2\n", - "0.8491788063497707 - Test scores step 1\n" + "0.8603970923442686 - Train scores step 2\n0.8491788063497706 - Test scores step 1\n" ] } ], "source": [ "step2_train = pd.read_csv(network_drive_dir + 'output/run1/step2/some_features_0_dev_prediction.csv', header=None)[0].to_numpy()\n", "step2_test = pd.read_csv(network_drive_dir + 'output/run1/step2/some_features_0_test_prediction.csv', header=None)[0].to_numpy()\n", "print(f'{pearsonr(step2_train, train_scores_true)[0]} - Train scores step 2')\n", "print(f'{pearsonr(step2_test, scores_true)[0]} - Test scores step 1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Medication Graph" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 6, "metadata": {}, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ - "0.861858453761435 - Train scores step 2\n", - "0.8616985288938903 - Test scores step 1\n", - "same as in 2\n" + "0.8618584537614349 - Train scores step 2\n0.8616985288938904 - Test scores step 1\nsame as in 2\n" ] } ], "source": [ "step3_graph_train = pd.read_csv(network_drive_dir + 'output/run1/step3/preprocessed_data_2019-08-06_21-40-40/graph_scores_train.csv')['score'].to_numpy()\n", "step3_graph_test = pd.read_csv(network_drive_dir + 'output/run1/step3/preprocessed_data_2019-08-06_21-40-40/graph_scores_test.csv')['score'].to_numpy()\n", "step3_train = step2_train.copy()\n", "step3_train[step3_graph_train != 10] = step3_graph_train[step3_graph_train != 10]\n", "step3_test = step2_test.copy()\n", "step3_test[step3_graph_test != 10] = step3_graph_test[step3_graph_test != 10]\n", "print(f'{pearsonr(step3_train, train_scores_true)[0]} - Train scores step 2')\n", "print(f'{pearsonr(step3_test, scores_true)[0]} - Test scores step 1')\n", "print('same as in 2')" ] }, { + "source": [ + "sum(step3_graph_test != 10)" + ], "cell_type": "code", - "execution_count": 22, "metadata": {}, + "execution_count": 8, "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ - "array([10. , 10. , 10. , 1.7674538 , 10. ,\n", - " 1.57582981, 10. , 10. , 1.69849115, 10. ,\n", - " 10. , 1.51822006, 10. , 10. , 10. ,\n", - " 10. , 10. , 1.67447326, 10. , 10. ,\n", - " 10. , 10. , 1.54055718, 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 1.60288008, 10. , 1.6917981 , 10. , 10. ,\n", - " 1.65231379, 10. , 4.13268305, 10. , 10. ,\n", - " 10. , 10. , 10. , 1.64516726, 10. ,\n", - " 10. , 10. , 1.67076833, 10. , 10. ,\n", - " 1.51822055, 1.69984453, 1.56302873, 10. , 10. ,\n", - " 10. , 10. , 1.62634715, 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 1.69013926, 10. , 10. ,\n", - " 1.47596558, 10. , 10. , 1.7737868 , 10. ,\n", - " 10. , 10. , 10. , 10. , 1.65886264,\n", - " 10. , 1.72479863, 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 1.6338272 , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 1.69834204,\n", - " 1.58688106, 10. , 10. , 10. , 1.69001543,\n", - " 10. , 1.53229244, 10. , 10. , 10. ,\n", - " 10. , 10. , 1.68304674, 1.57258378, 10. ,\n", - " 10. , 4.34921153, 10. , 10. , 1.76747282,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 1.64439553,\n", - " 1.70810151, 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 1.58855146, 10. ,\n", - " 1.5868374 , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 1.55658649,\n", - " 10. , 1.70055615, 1.68510345, 10. , 10. ,\n", - " 10. , 1.55588571, 1.68000957, 10. , 10. ,\n", - " 1.69580423, 10. , 1.63544031, 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 1.69947344, 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 1.6862104 , 10. , 10. , 10. , 10. ,\n", - " 10. , 1.71018486, 10. , 10. , 10. ,\n", - " 10. , 10. , 1.59721677, 1.60418829, 10. ,\n", - " 1.70256379, 10. , 1.81674065, 10. , 10. ,\n", - " 10. , 1.75757853, 10. , 1.69382919, 10. ,\n", - " 10. , 10. , 10. , 1.61465125, 10. ,\n", - " 10. , 1.62855809, 1.62757836, 10. , 1.59920505,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 1.60490655, 10. , 10. , 10. ,\n", - " 10. , 1.60533072, 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 1.57447801, 1.70176241,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 1.72362936, 1.77422083, 10. ,\n", - " 10. , 1.71163931, 10. , 1.69005578, 10. ,\n", - " 10. , 10. , 10. , 1.78086753, 10. ,\n", - " 10. , 10. , 10. , 1.66440614, 1.55073141,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 1.7167565 , 10. , 1.68716474, 10. , 10. ,\n", - " 10. , 10. , 1.60996524, 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 1.74065919, 10. , 10. , 10. , 1.69126686,\n", - " 10. , 10. , 10. , 2.15816834, 10. ,\n", - " 10. , 1.63756216, 10. , 10. , 10. ,\n", - " 1.62615826, 1.69949813, 10. , 10. , 10. ,\n", - " 1.65159263, 10. , 10. , 10. , 1.6226005 ,\n", - " 10. , 10. , 1.78153677, 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 1.62805305, 10. ,\n", - " 10. , 1.69030582, 10. , 10. , 10. ,\n", - " 10. , 10. , 1.70186892, 10. , 10. ,\n", - " 1.69604549, 10. , 10. , 10. , 10. ,\n", - " 1.59645883, 1.75049379, 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. , 1.66117857, 10. , 1.78725624,\n", - " 10. , 10. , 1.67971476, 1.86105161, 10. ,\n", - " 10. , 10. , 10. , 10. , 1.52834597,\n", - " 10. , 1.57686315, 10. , 10. , 10. ,\n", - " 10. , 4.30486055, 10. , 10. , 10. ,\n", - " 10. , 1.6973243 , 1.50951602, 10. , 10. ,\n", - " 10. , 1.72869246, 10. , 10. , 10. ,\n", - " 10. , 10. , 10. , 10. , 1.5754832 ,\n", - " 10. , 10. , 10. , 10. , 10. ,\n", - " 10. , 10. ])" + "94" ] }, - "execution_count": 22, "metadata": {}, - "output_type": "execute_result" + "execution_count": 8 + } + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "148" + ] + }, + "metadata": {}, + "execution_count": 9 } ], "source": [ - "step3_graph_test" + "sum(step3_graph_train != 10)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1.91962365, 3.02080456, 4.32715539, 1.7674538 , 1.53140277,\n", " 1.57582981, 2.69058631, 1.59874927, 1.69849115, 1.9051658 ,\n", " 1.8822309 , 1.51822006, 2.15805445, 1.4193063 , 3.44182167,\n", " 3.68950891, 4.43160836, 1.67447326, 4.63240929, 2.07942814,\n", " 2.76862427, 1.85719852, 1.54055718, 1.74520421, 2.2335856 ,\n", " 1.64414079, 2.40908164, 3.72529942, 3.05884908, 1.63124031,\n", " 1.60288008, 4.76373226, 1.6917981 , 3.35776882, 2.93615075,\n", " 1.65231379, 1.71684694, 4.13268305, 4.27047001, 3.84585319,\n", " 4.08149692, 3.18624199, 2.73008227, 1.64516726, 2.14033141,\n", " 2.39129219, 3.43961285, 1.67076833, 2.11208036, 2.12492612,\n", " 1.51822055, 1.69984453, 1.56302873, 1.94514012, 1.31966434,\n", " 3.18550331, 4.29035905, 1.62634715, 4.40107287, 1.44709537,\n", " 3.36152128, 1.64484862, 2.43473888, 3.63454681, 3.71400064,\n", " 4.38967013, 3.22783027, 1.69013926, 2.13835005, 1.70038675,\n", " 1.47596558, 0.94433552, 1.97221421, 1.7737868 , 3.92457313,\n", " 2.31992168, 3.91956857, 4.45567178, 2.20769868, 1.65886264,\n", " 1.36184832, 1.72479863, 4.21610542, 1.85339048, 1.71856811,\n", " 1.86391267, 2.92366817, 1.8488793 , 3.10642577, 2.48249743,\n", " 2.4919538 , 4.30604713, 1.23096491, 2.45489335, 3.03711203,\n", " 1.6338272 , 3.97904056, 3.77962891, 2.04450711, 3.38702596,\n", " 3.78519973, 3.64293611, 2.18667532, 3.31840785, 1.69834204,\n", " 1.58688106, 1.65787896, 2.41710945, 1.9228375 , 1.69001543,\n", " 4.13973586, 1.53229244, 1.67481343, 3.91139301, 3.96181135,\n", " 2.0334912 , 3.09225744, 1.68304674, 1.57258378, 3.05878476,\n", " 3.99347257, 4.34921153, 3.08045675, 2.14264782, 1.76747282,\n", " 1.58709927, 2.06912 , 4.069055 , 2.22283231, 4.07694842,\n", " 2.43285668, 1.49072362, 2.35695482, 3.91205839, 3.47863266,\n", " 1.65958811, 1.78430955, 2.38021665, 1.59949192, 1.64439553,\n", " 1.70810151, 4.17765358, 4.17608642, 3.87234784, 4.38938527,\n", " 1.31762992, 3.0725832 , 3.31467438, 1.58855146, 1.37983064,\n", " 1.5868374 , 0.94013128, 1.3929758 , 4.07037806, 2.40823212,\n", " 3.71738486, 1.76942317, 3.18742855, 1.91389529, 1.55658649,\n", " 3.4659005 , 1.70055615, 1.68510345, 1.39378786, 2.1099217 ,\n", " 1.89686412, 1.55588571, 1.68000957, 1.46804551, 1.5048956 ,\n", " 1.69580423, 2.62279439, 1.63544031, 2.7335974 , 3.92044456,\n", " 1.69068218, 1.61292242, 3.90594835, 4.31974304, 4.53325465,\n", " 1.69947344, 1.29012865, 4.15169143, 1.7339509 , 2.3134333 ,\n", " 4.22466654, 4.43978669, 3.49580178, 2.08742761, 2.11353594,\n", " 1.6862104 , 1.08351088, 2.53304925, 1.6245541 , 1.91957371,\n", " 4.12686945, 1.71018486, 2.35674727, 1.92881953, 1.42560422,\n", " 2.22854256, 2.15584282, 1.59721677, 1.60418829, 3.65929524,\n", " 1.70256379, 3.87029509, 1.81674065, 4.18622832, 2.04663929,\n", " 3.75671027, 1.75757853, 1.92724393, 1.69382919, 1.33292578,\n", " 3.45017525, 1.588931 , 3.81133735, 1.61465125, 1.23795452,\n", " 3.50098649, 1.62855809, 1.62757836, 2.23902587, 1.59920505,\n", " 1.64554549, 2.14390958, 1.62866843, 1.59751431, 2.49436411,\n", " 2.13731004, 1.60490655, 2.36171287, 3.91157873, 1.77449964,\n", " 4.2407261 , 1.60533072, 4.11296561, 1.34316987, 3.38236402,\n", " 2.62700257, 3.71982487, 1.8085862 , 1.57447801, 1.70176241,\n", " 3.71035709, 1.97845064, 3.08684286, 1.61046478, 1.75542087,\n", " 2.29367552, 3.48853812, 4.24122098, 4.27165522, 3.49592816,\n", " 2.17401834, 3.73973828, 1.72362936, 1.77422083, 2.16903226,\n", " 3.62461272, 1.71163931, 4.19810068, 1.69005578, 3.12736414,\n", " 2.9365779 , 3.71020531, 3.90583078, 1.78086753, 2.7620467 ,\n", " 1.64607358, 2.10031083, 3.86475059, 1.66440614, 1.55073141,\n", " 2.6050476 , 1.32454822, 2.74010879, 4.14889102, 3.83845493,\n", " 1.7167565 , 4.15179908, 1.68716474, 2.23055751, 4.42799125,\n", " 2.02344104, 2.0344298 , 1.60996524, 4.4121127 , 2.36971261,\n", " 1.51439188, 3.28725506, 2.7466327 , 2.23565972, 1.42885279,\n", " 3.37316893, 2.55641732, 2.48677688, 3.51434936, 2.71710757,\n", " 1.74065919, 3.89423443, 1.84143851, 1.45153321, 1.69126686,\n", " 1.20878499, 4.12768705, 3.93135253, 2.15816834, 4.25440236,\n", " 3.64090361, 1.63756216, 4.14144861, 2.85060833, 3.9150575 ,\n", " 1.62615826, 1.69949813, 1.62644883, 3.75231667, 4.36548079,\n", " 1.65159263, 2.9337424 , 3.61779266, 1.23706488, 1.6226005 ,\n", " 2.34443532, 3.6111164 , 1.78153677, 3.20187153, 2.30046119,\n", " 3.82115995, 3.71672393, 2.51533753, 3.17775269, 2.15910817,\n", " 3.90858714, 3.00489743, 2.28437026, 1.62805305, 4.28050426,\n", " 3.58575154, 1.69030582, 2.21792333, 2.04359441, 2.2373275 ,\n", " 2.34016778, 1.83259276, 1.70186892, 1.74213091, 3.02719001,\n", " 1.69604549, 1.91987975, 2.77698019, 1.69801652, 3.72743867,\n", " 1.59645883, 1.75049379, 4.00351329, 2.3714848 , 4.58482515,\n", " 4.36790784, 4.13196113, 2.60022676, 3.14720635, 3.49595544,\n", " 1.08128017, 3.5675986 , 1.66117857, 3.98749232, 1.78725624,\n", " 2.21801375, 2.29492981, 1.67971476, 1.86105161, 1.71484288,\n", " 4.46218263, 2.00993466, 2.05433472, 3.51673082, 1.52834597,\n", " 1.59793487, 1.57686315, 3.04636759, 2.15094979, 3.55612742,\n", " 4.37457148, 4.30486055, 2.57432097, 2.81452243, 3.94247364,\n", " 1.43113796, 1.6973243 , 1.50951602, 4.46484213, 3.8469918 ,\n", " 4.23737814, 1.72869246, 1.94759282, 4.4649204 , 2.90997844,\n", " 1.40385751, 4.2836937 , 2.44694886, 4.01425475, 1.5754832 ,\n", " 1.69643557, 4.00484059, 1.68132442, 4.54341418, 2.03674556,\n", " 1.75813428, 1.84689662])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "step3_test" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.861858453761435 - Train scores step 2\n", "same as in 2\n" ] } ], "source": [ "print(f'{pearsonr(step3_train, train_scores_true)[0]} - Train scores step 2')\n", "#print(f'{pearsonr(step3_test, scores_true)[0]} - Test scores step 1')\n", "print('same as in 2')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8603970923442688 - Train scores step 2\n", "0.8491788063497707 - Test scores step 1\n", "same as in 2\n" ] } ], "source": [ "step3_train = pd.read_csv(network_drive_dir + 'output/run1/step3/preprocessed_data_2019-08-06_21-40-40/some_features_0_dev_prediction_step2.csv', header=None)[0].to_numpy()\n", "step3_test = pd.read_csv(network_drive_dir + 'output/run1/step3/preprocessed_data_2019-08-06_21-40-40/some_features_0_test_prediction_step2.csv', header=None)[0].to_numpy()\n", "print(f'{pearsonr(step3_train, train_scores_true)[0]} - Train scores step 2')\n", "print(f'{pearsonr(step3_test, scores_true)[0]} - Test scores step 1')\n", "print('same as in 2')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8604672658183281 - Train scores step 2\n", "0.8498357692426267 - Test scores step 1\n" ] } ], "source": [ "step4_train = pd.read_csv(network_drive_dir + 'output/run1/step4/some_features_0_dev_prediction.csv', header=None)[0].to_numpy()\n", "step4_test = pd.read_csv(network_drive_dir + 'output/run1/step4/some_features_0_test_prediction.csv', header=None)[0].to_numpy()\n", "print(f'{pearsonr(step4_train, train_scores_true)[0]} - Train scores step 2')\n", "print(f'{pearsonr(step4_test, scores_true)[0]} - Test scores step 1')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.8" + "version": "3.8.5-final" } }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/mtc/ScoreAnalysis.ipynb b/mtc/ScoreAnalysis.ipynb new file mode 100755 index 0000000..bb24920 --- /dev/null +++ b/mtc/ScoreAnalysis.ipynb @@ -0,0 +1,457 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "import numpy as np\n", + "from scipy.stats import pearsonr\n", + "from pathlib import Path\n", + "pd.set_option(\"display.max_rows\", None, \"display.max_columns\", None, 'display.max_colwidth', None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "path_data = Path(os.environ.get('NLP_RAW_DATA')) / 'n2c2'\n", + "path_results = Path(os.environ.get('NLP_EXPERIMENT_PATH')) / 'submission_generation' / '03_12_2020_20_18_37_original_data'\n", + "df_train = pd.read_csv(path_data / 'clinicalSTS2019.train.txt', delimiter='\\t', names=['sentence_a', 'sentence_b', 'score_true'])\n", + "df_train['score_step2'] = pd.read_csv(path_results / 'normal' / 'step2_train_scores.csv', header=None)[0].to_numpy()\n", + "df_train['score_step4'] = pd.read_csv(path_results / 'normal' / 'step4_train_scores.csv', header=None)[0].to_numpy()\n", + "df_train['score_diff'] = (abs(df_train['score_step4'] - df_train['score_true'])) - abs((df_train['score_step2'] - df_train['score_true']))\n", + "\n", + "df_train_med = df_train[np.abs(df_train['score_step2'] - df_train['score_step4']) > 0.001]\n", + "#df_train_med = df_train[df_train['score_step2'] != df_train['score_step4']]" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(2.787484774665043, 1.388712509029767)" + ] + }, + "metadata": {}, + "execution_count": 51 + } + ], + "source": [ + "(df_train['score_true'].mean(), df_train['score_true'].std())" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(147, 8.95249695493301)" + ] + }, + "metadata": {}, + "execution_count": 69 + } + ], + "source": [ + "(len(df_train_med), 100 * len(df_train_med) / len(df_train))" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(2.0285714285714285, 1.048857826269141)" + ] + }, + "metadata": {}, + "execution_count": 53 + } + ], + "source": [ + "(df_train_med['score_true'].mean(), df_train_med['score_true'].std())" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.6959206986788156" + ] + }, + "metadata": {}, + "execution_count": 54 + } + ], + "source": [ + "np.sum(np.abs(df_train_med['score_step2'] - df_train_med['score_true'])**2)/len(df_train_med)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.5808822487618591" + ] + }, + "metadata": {}, + "execution_count": 55 + } + ], + "source": [ + "np.sum(np.abs(df_train_med['score_step4'] - df_train_med['score_true'])**2)/len(df_train_med)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " sentence_a \\\n", + "1351 Prozac 20 mg capsule 1 capsule by mouth one time daily. \n", + "1289 ondansetron [ZOFRAN] 4 mg tablet 1 tablet by mouth three times a day as needed. \n", + "873 hydrochlorothiazide 25 mg tablet one-half tablet by mouth every morning. \n", + "479 Prozac 20 mg capsule 3 capsules by mouth one time daily. \n", + "1253 amlodipine [NORVASC] 5 mg tablet 2 tablets by mouth one time daily. \n", + "\n", + " sentence_b \\\n", + "1351 ibuprofen [ADVIL] 200 mg tablet 3 tablets by mouth one time daily as needed. \n", + "1289 amoxicillin [AMOXIL] 500 mg capsule 2 capsules by mouth three times a day. \n", + "873 ibuprofen [MOTRIN] 600 mg tablet 1 tablet by mouth four times a day. \n", + "479 Aleve 220 mg tablet 1 tablet by mouth two times a day. \n", + "1253 hydrochlorothiazide 12.5 mg tablet 1 tablet by mouth one time daily. \n", + "\n", + " score_true score_step2 score_step4 score_diff \n", + "1351 1.5 1.718230 1.699896 -0.018334 \n", + "1289 3.0 1.683182 1.699137 -0.015955 \n", + "873 1.5 1.590499 1.697998 0.107499 \n", + "479 0.5 2.019769 1.678388 -0.341381 \n", + "1253 1.5 1.880972 1.700071 -0.180902 " + ], + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
sentence_asentence_bscore_truescore_step2score_step4score_diff
1351Prozac 20 mg capsule 1 capsule by mouth one time daily.ibuprofen [ADVIL] 200 mg tablet 3 tablets by mouth one time daily as needed.1.51.7182301.699896-0.018334
1289ondansetron [ZOFRAN] 4 mg tablet 1 tablet by mouth three times a day as needed.amoxicillin [AMOXIL] 500 mg capsule 2 capsules by mouth three times a day.3.01.6831821.699137-0.015955
873hydrochlorothiazide 25 mg tablet one-half tablet by mouth every morning.ibuprofen [MOTRIN] 600 mg tablet 1 tablet by mouth four times a day.1.51.5904991.6979980.107499
479Prozac 20 mg capsule 3 capsules by mouth one time daily.Aleve 220 mg tablet 1 tablet by mouth two times a day.0.52.0197691.678388-0.341381
1253amlodipine [NORVASC] 5 mg tablet 2 tablets by mouth one time daily.hydrochlorothiazide 12.5 mg tablet 1 tablet by mouth one time daily.1.51.8809721.700071-0.180902
\n
" + }, + "metadata": {}, + "execution_count": 56 + } + ], + "source": [ + "np.random.seed(5)\n", + "df_train_med.sample(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "df_test = pd.read_csv(path_data / 'clinicalSTS2019.test.txt', delimiter='\\t', names=['sentence_a', 'sentence_b'])\n", + "df_test['score_step2'] = pd.read_csv(path_results / 'normal' / 'step2_test_scores.csv', header=None)[0].to_numpy()\n", + "df_test['score_step4'] = pd.read_csv(path_results / 'normal' / 'step4_test_scores.csv', header=None)[0].to_numpy()\n", + "df_test['score_true'] = pd.read_csv(path_data / 'clinicalSTS2019.test.gs.sim.txt', header=None)[0].to_numpy()\n", + "df_test['score_diff'] = (abs(df_test['score_step4'] - df_test['score_true'])) - abs((df_test['score_step2'] - df_test['score_true']))\n", + "\n", + "df_test_med = df_test[np.abs(df_test['score_step2'] - df_test['score_step4']) > 0.001]\n", + "#df_test_med = df_test[df_test['score_step2'] != df_test['score_step4']]" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(1.7645631067961165, 1.5208707178893903)" + ] + }, + "metadata": {}, + "execution_count": 58 + } + ], + "source": [ + "(df_test['score_true'].mean(), df_test['score_true'].std())" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(94, 22.815533980582526)" + ] + }, + "metadata": {}, + "execution_count": 70 + } + ], + "source": [ + "(len(df_test_med), 100 * len(df_test_med) / len(df_test))" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(1.0957446808510638, 0.5041579573299865)" + ] + }, + "metadata": {}, + "execution_count": 60 + } + ], + "source": [ + "(df_test_med['score_true'].mean(), df_test_med['score_true'].std())" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "2.432894107188975" + ] + }, + "metadata": {}, + "execution_count": 61 + } + ], + "source": [ + "np.sum(np.abs(df_test_med['score_step2'] - df_test_med['score_true'])**2)/len(df_test_med)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.5612688029859126" + ] + }, + "metadata": {}, + "execution_count": 62 + } + ], + "source": [ + "np.sum(np.abs(df_test_med['score_step4'] - df_test_med['score_true'])**2)/len(df_test_med)" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " sentence_a \\\n", + "30 Qsymia 3.75-23 mg capsule multiphasic release 24 hour 1 capsule by mouth one time daily. \n", + "205 Aleve 220 mg tablet 1 tablet by mouth two times a day. \n", + "117 lisinopril [PRINIVIL/ZESTRIL] 10 mg tablet 2 tablets by mouth one time daily. \n", + "338 Tylenol Extra Strength 500 mg tablet 1 tablet by mouth as directed by prescriber as needed. \n", + "121 ibuprofen [MOTRIN] 600 mg tablet 1 tablet by mouth every 6 hours as needed. \n", + "\n", + " sentence_b \\\n", + "30 Aleve 220 mg tablet 2 tablets by mouth one time daily as needed. \n", + "205 acetaminophen [TYLENOL] 500 mg tablet 2 tablets by mouth three times a day. \n", + "117 naproxen [NAPROSYN] 500 mg tablet 1 tablet by mouth two times a day. \n", + "338 furosemide [LASIX] 20 mg tablet 3 tablets by mouth one time daily. \n", + "121 ibuprofen [ADVIL] 200 mg tablet 2-3 tablets by mouth every 4 hours as needed. \n", + "\n", + " score_step2 score_step4 score_true score_diff \n", + "30 2.324836 1.661350 0.0 -0.663486 \n", + "205 2.736895 1.680362 1.5 -1.056533 \n", + "117 2.287069 1.691543 1.0 -0.595525 \n", + "338 1.877145 1.694849 1.0 -0.182296 \n", + "121 3.907524 4.261679 3.0 0.354155 " + ], + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
sentence_asentence_bscore_step2score_step4score_truescore_diff
30Qsymia 3.75-23 mg capsule multiphasic release 24 hour 1 capsule by mouth one time daily.Aleve 220 mg tablet 2 tablets by mouth one time daily as needed.2.3248361.6613500.0-0.663486
205Aleve 220 mg tablet 1 tablet by mouth two times a day.acetaminophen [TYLENOL] 500 mg tablet 2 tablets by mouth three times a day.2.7368951.6803621.5-1.056533
117lisinopril [PRINIVIL/ZESTRIL] 10 mg tablet 2 tablets by mouth one time daily.naproxen [NAPROSYN] 500 mg tablet 1 tablet by mouth two times a day.2.2870691.6915431.0-0.595525
338Tylenol Extra Strength 500 mg tablet 1 tablet by mouth as directed by prescriber as needed.furosemide [LASIX] 20 mg tablet 3 tablets by mouth one time daily.1.8771451.6948491.0-0.182296
121ibuprofen [MOTRIN] 600 mg tablet 1 tablet by mouth every 6 hours as needed.ibuprofen [ADVIL] 200 mg tablet 2-3 tablets by mouth every 4 hours as needed.3.9075244.2616793.00.354155
\n
" + }, + "metadata": {}, + "execution_count": 77 + } + ], + "source": [ + "np.random.seed(9)\n", + "df_test_med.sample(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## M-Heads" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [], + "source": [ + "df_trainh = pd.read_csv(path_data / 'clinicalSTS2019.train.txt', delimiter='\\t', names=['sentence_a', 'sentence_b', 'score_true'])\n", + "df_trainh['score_step2'] = pd.read_csv(path_results / 'heads' / 'step1_train_scores.csv', header=None)[0].to_numpy()\n", + "df_trainh['score_step4'] = pd.read_csv(path_results / 'heads' / 'step4_train_scores.csv', header=None)[0].to_numpy()\n", + "df_trainh['score_diff'] = (abs(df_trainh['score_step4'] - df_trainh['score_true'])) - abs((df_trainh['score_step2'] - df_trainh['score_true']))\n", + "\n", + "df_trainh_med = df_trainh[np.abs(df_trainh['score_step2'] - df_trainh['score_step4']) > 0.001]\n", + "#df_trainh_med = df_trainh[df_trainh['score_step2'] != df_trainh['score_step4']]" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "146" + ] + }, + "metadata": {}, + "execution_count": 65 + } + ], + "source": [ + "len(df_trainh_med)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.6101250129135304" + ] + }, + "metadata": {}, + "execution_count": 66 + } + ], + "source": [ + "np.sum(np.abs(df_trainh_med['score_step2'] - df_trainh_med['score_true'])**2)/len(df_trainh_med)" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.5969832368522427" + ] + }, + "metadata": {}, + "execution_count": 67 + } + ], + "source": [ + "np.sum(np.abs(df_trainh_med['score_step4'] - df_trainh_med['score_true'])**2)/len(df_trainh_med)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5-final" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file