// LearningCurve.groovy
import java.beans.*
import java.io.Serializable
import java.util.Vector
import java.util.Enumeration
import org.pentaho.dm.kf.KFGroovyScript
import org.pentaho.dm.kf.GroovyHelper
import weka.core.*
import weka.gui.Logger
import weka.gui.beans.*
import weka.classifiers.bayes.NaiveBayes
import weka.classifiers.functions.Logistic
import weka.classifiers.Evaluation
import weka.classifiers.Classifier
import groovy.swing.SwingBuilder
import javax.swing.*
import java.awt.*

// add further imports here if necessary

/**
 * Example Groovy script that generates a learning curve for a classifier.
 * Allows the classifier to be connected via a "configuration" event or
 * specified via an environment variable (CLASSIFIER_NAME). Classifier options
 * and the parameters of the learning curve could be specified via environment
 * variables as well through just minor changes to the script.
 *
 * Generates both a "TextEvent" containing the curve information and a
 * "DataSetEvent". The latter can be visualized in a DataVisualizer component.
 *
 * Also demonstrates how to allow the user to set options for the script
 * via a graphical pop-up window.
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}org)
 */
class LearningCurve
  implements
    KFGroovyScript,
    EnvironmentHandler,
    BeanCommon,
    EventConstraints,
    UserRequestAcceptor,
    TrainingSetListener,
    TestSetListener,
    DataSourceListener,
    InstanceListener,
    TextListener,
    BatchClassifierListener,
    IncrementalClassifierListener,
    BatchClustererListener,
    GraphListener,
    ChartListener,
    ThresholdDataListener,
    VisualizableErrorListener,
    ConfigurationListener,
    Serializable {

  /** Don't delete!!
   * GroovyHelper has the following useful methods:
   *
   * notifyListenerType(Object event) - GroovyHelper will pass on the event
   *   to the appropriate listener type for you
   * ArrayList<TrainingSetListener> getTrainingSetListeners() - get
   *   a list of any directly connected components that are listening
   *   for TrainingSetEvents from us
   * ArrayList<TestSetListener> getTestSetListeners()
   * ArrayList<InstanceListener> getInstanceListeners()
   * ArrayList<TextListener> getTextListeners()
   * ArrayList<DataSourceListener> getDataSourceListeners()
   * ArrayList<BatchClassifierListener> getBatchClassifierListeners()
   * ArrayList<IncrementalClassifierListener> getIncrementalClassifierListeners()
   * ArrayList<BatchClustererListener> getBatchClustererListeners()
   * ArrayList<GraphListener> getGraphListeners()
   * ArrayList<ChartListener> getChartListeners()
   * ArrayList<ThresholdDataListener> getThresholdDataListeners()
   * ArrayList<VisualizableErrorListener> getVisualizableErrorListeners()
   */
  GroovyHelper m_helper

  /** Log to post messages to; may be null until setLog() is called */
  Logger m_log = null

  /** Environment used to resolve ${...} variables in the option values */
  Environment m_env = Environment.getSystemWide()

  /** Percentage of the (randomized) data held out for evaluation */
  String m_holdoutSize = "33.0"

  /** Number of training instances added at each step of the curve */
  String m_stepSize = "100"

  /** Maximum number of steps (points) on the curve */
  String m_numSteps = "10"

  /** Classifier class name; by default read from the CLASSIFIER_NAME variable */
  String m_classifierName = '${CLASSIFIER_NAME}'

  /** Optional option string for a classifier instantiated by name */
  String m_classifierOptions = null

  /** Source of the single allowed "trainingSet" connection (null if none) */
  Object m_incomingConnection = null

  /** Classifier component connected via a "configuration" event (null if none) */
  weka.gui.beans.Classifier m_connectedConfigurable = null

  /** Don't delete!! */
  void setManager(GroovyHelper manager) {
    m_helper = manager
  }

  /** Alter or add to in order to tell the KnowledgeFlow
   * environment whether a certain incoming connection type is allowed.
   * One "trainingSet" and one "configuration" connection are accepted.
   */
  boolean connectionAllowed(String eventName) {
    if (eventName.equals("trainingSet") && m_incomingConnection == null) {
      return true
    }
    if (eventName.equals("configuration") && m_connectedConfigurable == null) {
      return true
    }
    return false
  }

  /** Alter or add to in order to tell the KnowledgeFlow
   * environment whether a certain incoming connection type is allowed
   */
  boolean connectionAllowed(EventSetDescriptor esd) {
    return connectionAllowed(esd.getName())
  }

  /** Add (optional) code to do something when you have been
   * registered as a listener with a source for the named event
   */
  void connectionNotification(String eventName, Object source) {
    if (eventName.equals("trainingSet")) {
      m_incomingConnection = source
    }
    if (eventName.equals("configuration")) {
      // check the type of the configurable
      if (source instanceof weka.gui.beans.Classifier) {
        m_connectedConfigurable = (weka.gui.beans.Classifier) source
      } else {
        reportError("Connected configurable is not a classifier!!")
      }
    }
  }

  /** Add (optional) code to do something when you have been
   * deregistered as a listener with a source for the named event
   */
  void disconnectionNotification(String eventName, Object source) {
    if (eventName.equals("trainingSet")) {
      m_incomingConnection = null
    }
    if (eventName.equals("configuration")) {
      m_connectedConfigurable = null
    }
  }

  /** Custom name of this component. Do something with it if you
   * like. GroovyHelper already stores it and alters the icon text
   * for you
   */
  void setCustomName(String name) { }

  /** Custom name of this component. No need to return anything
   * GroovyHelper already stores it and alters the icon text
   * for you
   */
  String getCustomName() {
    return null
  }

  /** Add code to return true when you are busy doing something */
  boolean isBusy() {
    return false
  }

  /** Store and use this logging object in order to post messages
   * to the log
   */
  void setLog(Logger logger) {
    m_log = logger
  }

  /** Store and use this Environment object in order to lookup and
   * use the values of environment variables
   */
  void setEnvironment(Environment env) {
    m_env = env
  }

  /** Stop any processing (if possible) */
  void stop() { }

  /** Alter or add to in order to tell the KnowledgeFlow
   * whether, at the current time, the named event could
   * be generated. This script produces "text" and "dataSet" events.
   */
  boolean eventGeneratable(String eventName) {
    if (eventName.equals("text")) {
      return true
    }
    if (eventName.equals("dataSet")) {
      return true
    }
    return false
  }

  /** Implement this to tell KnowledgeFlow about any methods
   * that the user could invoke (i.e. to show a popup visualization
   * or something).
   */
  Enumeration enumerateRequests() {
    Vector items = new Vector(0)
    items.add("Set options...")
    return items.elements()
  }

  /** Make the user-requested action happen here.
   * "Set options..." pops up a small frame in which the holdout size,
   * number of steps and step size can be edited. Values are stored when
   * Enter is pressed in a field or when OK is clicked.
   */
  void performRequest(String requestName) {
    if (!requestName.equals("Set options...")) {
      return
    }

    def swing = new SwingBuilder()

    // Declared up front so the event-handler closures can refer to them
    // directly instead of relying on builder variable/method dispatch.
    JTextField holdoutField = null
    JTextField numStepsField = null
    JTextField stepSizeField = null
    JFrame optionsFrame = null

    optionsFrame = swing.frame(title: 'Learning Curve Options', size: [300, 600]) {
      borderLayout()

      // one labelled text field per row
      panel(constraints: BorderLayout.NORTH) {
        boxLayout(axis: BoxLayout.Y_AXIS)
        panel() {
          borderLayout()
          label(text: 'Holdout set size: ', constraints: BorderLayout.WEST)
          holdoutField = textField(text: m_holdoutSize, columns: 6,
            actionPerformed: { m_holdoutSize = holdoutField.text },
            constraints: BorderLayout.CENTER)
        }
        panel() {
          borderLayout()
          label(text: 'Number of steps: ', constraints: BorderLayout.WEST)
          numStepsField = textField(text: m_numSteps, columns: 6,
            actionPerformed: { m_numSteps = numStepsField.text },
            constraints: BorderLayout.CENTER)
        }
        panel() {
          borderLayout()
          label(text: 'Step size: ', constraints: BorderLayout.WEST)
          stepSizeField = textField(text: m_stepSize, columns: 6,
            actionPerformed: { m_stepSize = stepSizeField.text },
            constraints: BorderLayout.CENTER)
        }
      }

      // OK stores all three values; CANCEL just closes the frame
      panel(constraints: BorderLayout.SOUTH) {
        boxLayout(axis: BoxLayout.X_AXIS)
        button(text: 'OK', actionPerformed: {
          m_holdoutSize = holdoutField.text
          m_numSteps = numStepsField.text
          m_stepSize = stepSizeField.text
          optionsFrame.dispose()
        })
        button(text: "CANCEL", actionPerformed: {
          optionsFrame.dispose()
        })
      }
    }

    optionsFrame.pack()
    optionsFrame.setVisible(true)
  }

  //--------------- Private helpers ------------------

  /** Prefix that identifies this component in KnowledgeFlow status messages */
  private String statusPrefix() {
    return 'LearningCurve$' + hashCode() + '|'
  }

  /** Flags an error in the status area and writes the detail to the log (if a log is set) */
  private void reportError(String detail) {
    if (m_log != null) {
      m_log.statusMessage(statusPrefix() + "ERROR (see log for details)")
      m_log.logMessage("[LearningCurve] " + detail)
    }
  }

  /**
   * Best-effort environment variable substitution for a single value.
   * Returns the value unchanged if there is no environment, the value is
   * empty, or substitution fails (the failure is written to the log).
   */
  private String substituteEnv(String value) {
    if (m_env == null || value == null || value.length() == 0) {
      return value
    }
    try {
      return m_env.substitute(value)
    } catch (Exception ex) {
      if (m_log != null) {
        m_log.logMessage("[LearningCurve] Unable to substitute environment variables in '"
          + value + "': " + ex.getMessage())
      }
      return value
    }
  }

  /**
   * Returns a fresh copy of the classifier to use: the template of the
   * connected Classifier component if there is one, otherwise a classifier
   * instantiated from m_classifierName / m_classifierOptions.
   * Reports the problem and returns null if no classifier can be obtained.
   */
  private weka.classifiers.Classifier resolveClassifier() {
    if (m_connectedConfigurable != null) {
      return weka.classifiers.Classifier.makeCopy(
        m_connectedConfigurable.getClassifierTemplate())
    }

    // try and instantiate from the supplied classifier name
    String classifierName = substituteEnv(m_classifierName)
    String classifierOptions = substituteEnv(m_classifierOptions)

    if (classifierName == null || classifierName.length() == 0) {
      reportError("No classifier supplied!")
      return null
    }

    String[] splitOptions = null
    if (classifierOptions != null && classifierOptions.length() > 0) {
      try {
        splitOptions = Utils.splitOptions(classifierOptions)
      } catch (Exception ex) {
        reportError("Problem parsing classifier options")
        return null
      }
    }

    try {
      return weka.classifiers.Classifier.forName(classifierName, splitOptions)
    } catch (Exception ex) {
      // e.g. CLASSIFIER_NAME is not set, or the class/options are invalid
      reportError("Unable to instantiate classifier '" + classifierName + "': "
        + ex.getMessage())
      return null
    }
  }

  //--------------- Incoming events ------------------
  //--------------- Implement as necessary -----------

  /**
   * Builds the learning curve for the incoming training set.
   *
   * The data is copied and shuffled with a fixed seed; the last
   * holdout-percent of it is set aside for evaluation. At each step another
   * step-size instances are added to the training set, a fresh copy of the
   * classifier is trained and its percent correct on the holdout set is
   * recorded. The curve is sent on as a TextEvent ("n,percentCorrect" lines)
   * and as a DataSetEvent. Structure-only events are ignored; configuration
   * problems are reported to the log and no events are generated.
   */
  void acceptTrainingSet(TrainingSetEvent e) {
    if (e.isStructureOnly()) {
      return
    }

    Instances insts = new Instances(e.getTrainingSet())
    insts.randomize(new Random(1))

    weka.classifiers.Classifier classifierToUse = resolveClassifier()
    if (classifierToUse == null) {
      return // problem already reported
    }

    double holdoutPercent
    int stepSize
    int numSteps
    try {
      holdoutPercent = Double.parseDouble(substituteEnv(m_holdoutSize).trim())
      stepSize = Integer.parseInt(substituteEnv(m_stepSize).trim())
      numSteps = Integer.parseInt(substituteEnv(m_numSteps).trim())
    } catch (NumberFormatException ex) {
      reportError("Problem parsing learning curve options: " + ex.getMessage())
      return
    }
    if (!(holdoutPercent > 0 && holdoutPercent < 100) || stepSize <= 0 || numSteps <= 0) {
      reportError("Holdout size must be between 0 and 100 (exclusive) and "
        + "step size/number of steps must be positive")
      return
    }

    int total = insts.numInstances()
    int numInHoldout = (int) (holdoutPercent / 100.0 * total)
    int numTrainable = total - numInHoldout
    if (numInHoldout <= 0 || numTrainable <= 0) {
      reportError("Not enough instances (" + total + ") to create both a training "
        + "and a holdout set")
      return
    }

    // the holdout set is the tail of the shuffled data
    Instances holdoutI = new Instances(insts, numTrainable, numInHoldout)

    String classifierSetUpString = classifierToUse.getClass().toString() + " "
    if (classifierToUse instanceof OptionHandler) {
      classifierSetUpString +=
        Utils.joinOptions(((OptionHandler) classifierToUse).getOptions())
    }
    if (m_log != null) {
      m_log.logMessage("[LearningCurve] Using classifier " + classifierSetUpString)
    }

    // create the instances structure to hold the learning curve results
    Attribute setSize = new Attribute("NumInstances")
    Attribute aucA = new Attribute("PercentCorrect")
    FastVector atts = new FastVector()
    atts.addElement(setSize)
    atts.addElement(aucA)
    // The preceding "__" tells the DataVisualizer to connect the points with lines
    Instances learnCInstances =
      new Instances("__Learning curve: " + classifierSetUpString, atts, 0)

    StringBuilder buff = new StringBuilder()
    Instances training = new Instances(insts, 0)

    for (int i = 0; i < numSteps; i++) {
      if (m_log != null) {
        m_log.statusMessage(statusPrefix() + "Processing set " + (i + 1))
      }

      // cumulative training size for this step, capped at what is available
      // (computed as long so that large step settings cannot overflow)
      long wanted = ((long) (i + 1)) * stepSize
      boolean done = wanted >= numTrainable
      int numInThisStep = done ? numTrainable : (int) wanted

      // training already holds the instances of all previous steps
      for (int k = training.numInstances(); k < numInThisStep; k++) {
        training.add(insts.instance(k))
      }

      // train a fresh copy on this set and evaluate on the holdout data
      weka.classifiers.Classifier newModel =
        weka.classifiers.Classifier.makeCopy(classifierToUse)
      newModel.buildClassifier(training)
      Evaluation eval = new Evaluation(holdoutI)
      eval.evaluateModel(newModel, holdoutI)
      double pc = (1.0 - eval.errorRate()) * 100.0

      buff.append("" + numInThisStep + "," + pc + "\n")

      Instance newInst = new Instance(2)
      newInst.setValue(0, (double) numInThisStep)
      newInst.setValue(1, pc)
      learnCInstances.add(newInst)

      if (done) {
        break
      }
    }

    if (m_log != null) {
      m_log.statusMessage(statusPrefix() + "Finished.")
    }

    m_helper.notifyTextListeners(new TextEvent(this, buff.toString(), "learning curve"))
    m_helper.notifyDataSourceListeners(new DataSetEvent(this, learnCInstances))
  }

  // The remaining listener methods are required by the implemented
  // interfaces but this script does nothing with those events.

  void acceptTestSet(TestSetEvent e) { }

  void acceptDataSet(DataSetEvent e) { }

  void acceptInstance(InstanceEvent e) { }

  void acceptText(TextEvent e) { }

  void acceptClassifier(BatchClassifierEvent e) { }

  void acceptClassifier(IncrementalClassifierEvent e) { }

  void acceptClusterer(BatchClustererEvent e) { }

  void acceptGraph(GraphEvent e) { }

  void acceptDataPoint(ChartEvent e) { }

  void acceptDataSet(ThresholdDataEvent e) { }

  void acceptDataSet(VisualizableErrorEvent e) { }

  void acceptConfiguration(ConfigurationEvent e) { }
}
, multiple selections available,