From 2a0cb435d7084cdbafa355503a7d263ebd3d8b30 Mon Sep 17 00:00:00 2001
From: Matt Ward <ward.matt@gmail.com>
Date: Wed, 12 Feb 2014 21:43:40 +0000
Subject: [PATCH] Implement EnvDTE.CodeInterface.AddFunction()

---
 .../Src/Refactoring/CSharpCodeGenerator.cs    | 14 ++++++++
 .../Project/Src/EnvDTE/CodeInterface.cs       | 35 +++++++++++++++++--
 .../Project/Src/ICodeGenerator.cs             |  1 +
 .../Project/Src/ThreadSafeCodeGenerator.cs    |  5 +++
 .../Base/Project/Refactoring/CodeGenerator.cs |  5 +++
 5 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/src/AddIns/BackendBindings/CSharpBinding/Project/Src/Refactoring/CSharpCodeGenerator.cs b/src/AddIns/BackendBindings/CSharpBinding/Project/Src/Refactoring/CSharpCodeGenerator.cs
index 512129cbd0..a14c7bf49b 100644
--- a/src/AddIns/BackendBindings/CSharpBinding/Project/Src/Refactoring/CSharpCodeGenerator.cs
+++ b/src/AddIns/BackendBindings/CSharpBinding/Project/Src/Refactoring/CSharpCodeGenerator.cs
@@ -159,6 +159,20 @@ namespace CSharpBinding.Refactoring
 			}
 		}
 		
+		public override void AddMethodAtStart(ITypeDefinition declaringType, Accessibility accessibility, IType returnType, string name)
+		{
+			SDRefactoringContext context = declaringType.CreateRefactoringContext();
+			var typeDecl = context.GetNode<TypeDeclaration>();
+			using (var script = context.StartScript()) {
+				var astBuilder = context.CreateTypeSystemAstBuilder(typeDecl.FirstChild);
+				var methodDecl = new MethodDeclaration();
+				methodDecl.Name = name;
+				methodDecl.ReturnType = astBuilder.ConvertType(context.Compilation.Import(returnType));
+				
+				script.AddTo(typeDecl, methodDecl);
+			}
+		}
+		
 		public override void ChangeAccessibility(IEntity entity, Accessibility newAccessiblity)
 		{
 			// TODO script.ChangeModifiers(...)
diff --git a/src/AddIns/Misc/PackageManagement/Project/Src/EnvDTE/CodeInterface.cs b/src/AddIns/Misc/PackageManagement/Project/Src/EnvDTE/CodeInterface.cs
index 83bef5fb5b..51eaf7f4dc 100644
--- a/src/AddIns/Misc/PackageManagement/Project/Src/EnvDTE/CodeInterface.cs
+++ b/src/AddIns/Misc/PackageManagement/Project/Src/EnvDTE/CodeInterface.cs
@@ -20,6 +20,7 @@ using System;
 using System.Linq;
 using System.Text;
 using ICSharpCode.NRefactory.TypeSystem;
+using ICSharpCode.NRefactory.TypeSystem.Implementation;
 using ICSharpCode.SharpDevelop.Dom;
 
 namespace ICSharpCode.PackageManagement.EnvDTE
@@ -37,9 +38,39 @@ namespace ICSharpCode.PackageManagement.EnvDTE
 		
 		public global::EnvDTE.CodeFunction AddFunction(string name, global::EnvDTE.vsCMFunction kind, object type, object Position = null, global::EnvDTE.vsCMAccess Access = global::EnvDTE.vsCMAccess.vsCMAccessPublic)
 		{
-		//	var codeGenerator = new ClassCodeGenerator(Class);
-		//	return codeGenerator.AddPublicMethod(name, (string)type);
+			IType returnType = GetMethodReturnType((string)type);
+			
+			context.CodeGenerator.AddMethodAtStart(typeDefinition, Access.ToAccessibility(), returnType, name);
+			
+			ReloadTypeDefinition();
+			
+			IMethod method = typeDefinition.Methods.FirstOrDefault(f => f.Name == name);
+			if (method != null) {
+				return new CodeFunction(context, method);
+			}
 			return null;
 		}
+		
+		IType GetMethodReturnType(string typeName)
+		{
+			var fullTypeName = new FullTypeName(typeName);
+			
+			IType type = typeDefinition.Compilation.FindType(fullTypeName);
+			if (type != null) {
+				return type;
+			}
+			
+			return new UnknownType(fullTypeName);
+		}
+		
+		void ReloadTypeDefinition()
+		{
+			ICompilation compilation = context.DteProject.GetCompilationUnit(typeDefinition.BodyRegion.FileName);
+			
+			ITypeDefinition matchedTypeDefinition = compilation.MainAssembly.GetTypeDefinition(typeDefinition.FullTypeName);
+			if (matchedTypeDefinition != null) {
+				typeDefinition = matchedTypeDefinition;
+			}
+		}
 	}
 }
diff --git a/src/AddIns/Misc/PackageManagement/Project/Src/ICodeGenerator.cs b/src/AddIns/Misc/PackageManagement/Project/Src/ICodeGenerator.cs
index 16f1acc9b7..98d5b8267c 100644
--- a/src/AddIns/Misc/PackageManagement/Project/Src/ICodeGenerator.cs
+++ b/src/AddIns/Misc/PackageManagement/Project/Src/ICodeGenerator.cs
@@ -27,5 +27,6 @@ namespace ICSharpCode.PackageManagement
 		void AddImport(FileName fileName, string name);
 		void MakePartial(ITypeDefinition typeDefinition);
 		void AddFieldAtStart(ITypeDefinition typeDefinition, Accessibility accessibility, IType fieldType, string name);
+		void AddMethodAtStart(ITypeDefinition typeDefinition, Accessibility accessibility, IType returnType, string name);
 	}
 }
diff --git a/src/AddIns/Misc/PackageManagement/Project/Src/ThreadSafeCodeGenerator.cs b/src/AddIns/Misc/PackageManagement/Project/Src/ThreadSafeCodeGenerator.cs
index f4025132be..d7e5725701 100644
--- a/src/AddIns/Misc/PackageManagement/Project/Src/ThreadSafeCodeGenerator.cs
+++ b/src/AddIns/Misc/PackageManagement/Project/Src/ThreadSafeCodeGenerator.cs
@@ -54,5 +54,10 @@ namespace ICSharpCode.PackageManagement
 		{
 			InvokeIfRequired(() => codeGenerator.AddFieldAtStart(typeDefinition, accessibility, fieldType, name));
 		}
+		
+		public void AddMethodAtStart(ITypeDefinition typeDefinition, Accessibility accessibility, IType returnType, string name)
+		{
+			InvokeIfRequired(() => codeGenerator.AddMethodAtStart(typeDefinition, accessibility, returnType, name));
+		}
 	}
 }
diff --git a/src/Main/Base/Project/Refactoring/CodeGenerator.cs b/src/Main/Base/Project/Refactoring/CodeGenerator.cs
index 5ff619310e..e9bf579bfc 100644
--- a/src/Main/Base/Project/Refactoring/CodeGenerator.cs
+++ b/src/Main/Base/Project/Refactoring/CodeGenerator.cs
@@ -93,6 +93,11 @@ namespace ICSharpCode.SharpDevelop.Refactoring
 			throw new NotSupportedException("Feature not supported!");
 		}
 		
+		public virtual void AddMethodAtStart(ITypeDefinition declaringType, Accessibility accessibility, IType returnType, string name)
+		{
+			throw new NotSupportedException("Feature not supported!");
+		}
+		
 		public virtual void ChangeAccessibility(IEntity entity, Accessibility newAccessiblity)
 		{
 			throw new NotSupportedException("Feature not supported!");